# Core libraries
import pandas as pd
import numpy as np
import joblib
from collections import defaultdict
import torch
# Sklearn imports
from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV
from sklearn.metrics import auc, classification_report, roc_auc_score, confusion_matrix, precision_recall_curve, matthews_corrcoef, ConfusionMatrixDisplay
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.utils import class_weight
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
classification_report, roc_auc_score, matthews_corrcoef
)
from sklearn.metrics import roc_curve
# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Concatenate, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.metrics import Precision, Recall, AUC
from tensorflow.keras import backend as K
import torch
#Tabnet
from pytorch_tabnet.tab_model import TabNetClassifier
# Imbalanced Learning
from imblearn.over_sampling import SMOTENC
# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Warnings
from warnings import filterwarnings
filterwarnings('ignore')
from collections import Counter
This notebook contains two types of models that try to predict the following:
- Failure event
- Maintenance needs
For each objective, a model is built using Random Forest (an ensemble ML model) and TabNet (built using neural networks). Therefore there are 4 models in this notebook. For a detailed analysis of the models, read the report. To avoid redundant information, I have only focused on model comparisons of best accuracy in K-fold. I have focused on explaining the models here and the evaluations in the report.
Model explanations have been done only for Failure event, whereas for Maintenance needs only the model and results are displayed, as the same model is replicated but different dataset labels are passed.
Failure Event¶
Data Loading and Visualization¶
# Load the maintenance dataset and preview the first rows
df = pd.read_csv('military_asset_maintenance_data.csv')
df.head()
| Asset_ID | Asset_Type | Age_of_Asset | Usage_Hours | Temperature | Pressure | Fuel_Consumption | Vibration_Levels | Humidity | Location | Maintenance_History | Failure_Event | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | Ship | 9 | 5740 | 44.968711 | 137.776053 | 49.336583 | 0.043240 | 59.341921 | Temperate | 2 | 0 |
| 1 | 2 | Aircraft | 16 | 8326 | 74.398522 | 37.204956 | 95.947730 | 0.171963 | 27.762498 | Tropical | 2 | 0 |
| 2 | 3 | Ship | 11 | 2667 | 58.640931 | 146.077726 | 72.209295 | 0.285799 | 48.473243 | Desert | 1 | 0 |
| 3 | 4 | Ship | 14 | 8436 | 73.961007 | 141.785229 | 33.231257 | 0.540367 | 25.760009 | Temperate | 1 | 0 |
| 4 | 5 | Aircraft | 8 | 6835 | 45.759227 | 56.240621 | 38.631240 | 0.204255 | 60.852518 | Tropical | 2 | 0 |
# Number of null values per column (all zero -> no missing data)
df.isnull().sum()
Asset_ID 0 Asset_Type 0 Age_of_Asset 0 Usage_Hours 0 Temperature 0 Pressure 0 Fuel_Consumption 0 Vibration_Levels 0 Humidity 0 Location 0 Maintenance_History 0 Failure_Event 0 dtype: int64
df.count()
Asset_ID 50000 Asset_Type 50000 Age_of_Asset 50000 Usage_Hours 50000 Temperature 50000 Pressure 50000 Fuel_Consumption 50000 Vibration_Levels 50000 Humidity 50000 Location 50000 Maintenance_History 50000 Failure_Event 50000 dtype: int64
df.Asset_ID.value_counts().unique()
array([1])
We can see that the dataset contains no null or duplicate values.
df['Failure_Event'].value_counts()
Failure_Event 0 42436 1 7564 Name: count, dtype: int64
# Pie chart of the target class balance; the minority slice is exploded for emphasis
failure_counts = df['Failure_Event'].value_counts()
failure_counts.plot(
    kind='pie',
    labels=['No Failure', 'Failure'],
    autopct='%1.1f%%',
    startangle=90,
    explode=(0, 0.1),
    figsize=(4, 4),
)
plt.ylabel("")  # Remove the default y-axis label
plt.title("Failure Event Distribution")
plt.show()
df.value_counts('Asset_Type')
Asset_Type Aircraft 16709 Vehicle 16682 Ship 16609 Name: count, dtype: int64
df[["Asset_Type","Failure_Event"]].value_counts()
Asset_Type Failure_Event Aircraft 0 14204 Vehicle 0 14127 Ship 0 14105 Vehicle 1 2555 Aircraft 1 2505 Ship 1 2504 Name: count, dtype: int64
The dataset is alarmingly imbalanced, with roughly 2,504 failures per asset type compared to about 14,204 non-failure events.
df[["Asset_Type", "Failure_Event"]].value_counts().unstack().plot(kind='bar', stacked=True, figsize=(6, 4))
plt.title("Failure Events by Asset Type")
plt.xlabel("Asset Type")
plt.ylabel("Count")
plt.legend(title="Failure Event", labels=["No Failure", "Failure"])
plt.show()
# Columns treated as numeric: drop the categoricals and the target
numeric_columns = df.drop(columns=["Asset_Type", "Location", "Failure_Event"]).columns

# One histogram (with KDE overlay) per numeric column, laid out on a 3x3 grid
plt.figure(figsize=(16, 12))
for position, column in enumerate(numeric_columns, start=1):
    plt.subplot(3, 3, position)
    sns.histplot(df[column], kde=True)
    plt.title(f'Distribution of {column}')
plt.tight_layout()
plt.show()
Asset Information
Age_of_Asset: Age ranges mostly between 1 to 18 years. There's a relatively even spread, but a slight dip around ages 6–10 could suggest fewer assets in that age range or maintenance/replacement around that period.Usage_Hours: Fairly uniform distribution up to ~10,000 hours. Indicates that usage varies widely, which is good for training models to learn from both low and high-usage cases.Sensor and Operational Data: Temperature, Pressure, Fuel_Consumption, Vibration_Levels, Humidity
These features show approximately uniform distributions, with minor fluctuations. This suggests good feature variability, which is beneficial for training — the model has exposure to a wide range of conditions. No clear skewness or major outliers are visible, indicating data is already well-cleaned.
# Correlation matrix over numeric columns only (categoricals and target excluded)
corr_matrix = df.drop(columns=["Asset_Type", "Location", "Failure_Event"]).corr()

# Annotated heatmap of pairwise Pearson correlations
plt.figure(figsize=(12, 10))
sns.heatmap(
    corr_matrix,
    annot=True,
    fmt=".2f",
    cmap='coolwarm',
    square=True,
    cbar_kws={"shrink": .8},
)
plt.title("Correlation Matrix of Features", fontsize=16)
plt.tight_layout()
plt.show()
No multicollinearity among the variables.
# Box plots of each numeric feature, split by the failure label
plt.figure(figsize=(16, 12))
for position, column in enumerate(numeric_columns, start=1):
    plt.subplot(3, 3, position)
    sns.boxplot(x='Failure_Event', y=column, data=df)
    plt.title(f'Box Plot of {column} by Failure Event')
plt.tight_layout()
plt.show()
# Flag values lying more than 3 standard deviations from the column mean (z-score rule)
outlier_threshold = 3
outlier_columns = numeric_columns
outliers = {}
for col in outlier_columns:
    z = (df[col] - df[col].mean()) / df[col].std()
    mask = np.abs(z) > outlier_threshold
    outliers[col] = df[mask][col]
    print(f"Outliers in {col}:")
    print(outliers[col])
Outliers in Asset_ID: Series([], Name: Asset_ID, dtype: int64) Outliers in Age_of_Asset: Series([], Name: Age_of_Asset, dtype: int64) Outliers in Usage_Hours: Series([], Name: Usage_Hours, dtype: int64) Outliers in Temperature: Series([], Name: Temperature, dtype: float64) Outliers in Pressure: Series([], Name: Pressure, dtype: float64) Outliers in Fuel_Consumption: Series([], Name: Fuel_Consumption, dtype: float64) Outliers in Vibration_Levels: Series([], Name: Vibration_Levels, dtype: float64) Outliers in Humidity: Series([], Name: Humidity, dtype: float64) Outliers in Maintenance_History: Series([], Name: Maintenance_History, dtype: int64)
Most features show some overlap between failure and non-failure events, but a few features exhibit distinct shifts or spread differences, indicating their predictive potential.
Older assets and those with higher usage hours tend to experience more failures.
Higher vibration levels and elevated temperatures are noticeably associated with failure events, indicating they are strong predictors of potential issues.
Fuel consumption and maintenance history show slight increases in failed assets, suggesting possible early warning signs.
Pressure and humidity show minimal differences, indicating limited impact individually.
df[["Location", "Failure_Event"]].value_counts().unstack().plot(kind='bar', stacked=True, figsize=(6, 4))
plt.title("Failure Events by Asset Type")
plt.xlabel("Asset Type")
plt.ylabel("Count")
plt.legend(title="Failure Event", labels=["No Failure", "Failure"])
<matplotlib.legend.Legend at 0x33aa0bdc0>
Data preparation¶
Mapping categorical variables¶
# Integer-encode the two string categoricals in place, keeping value -> code maps
categorical_cols = ['Asset_Type', 'Location']
category_mappings = {}  # remembered so encodings can be reversed / reused later

for col in categorical_cols:
    # Codes follow first-appearance order in the column
    codes = {value: code for code, value in enumerate(df[col].unique())}
    df[col] = df[col].map(codes)
    category_mappings[col] = codes

for col, mapping in category_mappings.items():
    print(f"{col} mapping: {mapping}")
Asset_Type mapping: {'Ship': 0, 'Aircraft': 1, 'Vehicle': 2}
Location mapping: {'Temperate': 0, 'Tropical': 1, 'Desert': 2}
Feature Engineering to capture complex relationships in data¶
df['Thermal_Stress'] = df['Usage_Hours'] * df['Temperature']
df['Age_Vibration_Interaction'] = df['Age_of_Asset'] * df['Vibration_Levels']
df['Fuel_Efficiency'] = df['Fuel_Consumption'] / (df['Usage_Hours'] + 1e-5) # avoid division by 0
df['Pressure_Temp_Interaction'] = df['Pressure'] * df['Temperature']
df['Operational_Stress_Index'] = (df['Vibration_Levels'] + df['Pressure'] + df['Temperature'] + df['Usage_Hours'])/ (df['Age_of_Asset'] + 1e-5)
df.drop(columns=['Asset_ID'], inplace=True)
df.head()
| Asset_Type | Age_of_Asset | Usage_Hours | Temperature | Pressure | Fuel_Consumption | Vibration_Levels | Humidity | Location | Maintenance_History | Failure_Event | Thermal_Stress | Age_Vibration_Interaction | Fuel_Efficiency | Pressure_Temp_Interaction | Operational_Stress_Index | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 9 | 5740 | 44.968711 | 137.776053 | 49.336583 | 0.043240 | 59.341921 | 0 | 2 | 0 | 258120.403835 | 0.389162 | 0.008595 | 6195.611574 | 658.086825 |
| 1 | 1 | 16 | 8326 | 74.398522 | 37.204956 | 95.947730 | 0.171963 | 27.762498 | 1 | 2 | 0 | 619442.090624 | 2.751404 | 0.011524 | 2767.993702 | 527.360635 |
| 2 | 0 | 11 | 2667 | 58.640931 | 146.077726 | 72.209295 | 0.285799 | 48.473243 | 2 | 1 | 0 | 156395.363951 | 3.143794 | 0.027075 | 8566.133914 | 261.091077 |
| 3 | 0 | 14 | 8436 | 73.961007 | 141.785229 | 33.231257 | 0.540367 | 25.760009 | 0 | 1 | 0 | 623935.057788 | 7.565140 | 0.003939 | 10486.578333 | 618.020030 |
| 4 | 1 | 8 | 6835 | 45.759227 | 56.240621 | 38.631240 | 0.204255 | 60.852518 | 1 | 2 | 0 | 312764.319305 | 1.634040 | 0.005652 | 2573.527343 | 867.149429 |
Feature Engineering Rationale¶
These features are designed to:
- Capture interactions between variables that may not be obvious from individual inputs.
- Reflect real-world mechanical stress patterns.
- Improve model predictive performance by injecting domain knowledge into the data.
Thermal_Stress = Usage_Hours * Temperature
- Captures the cumulative thermal load an asset experiences.
- Reflects extended usage in high-temperature conditions, which can accelerate wear and tear.
Age_Vibration_Interaction = Age_of_Asset * Vibration_Levels
- Combines mechanical wear (vibration) with asset aging.
- Older assets with high vibration levels are more prone to failure, making this an important risk indicator.
Fuel_Efficiency = Fuel_Consumption / (Usage_Hours + 1e-5)
- Represents how efficiently an asset consumes fuel relative to its operational time.
- Declining efficiency may signal engine or system degradation, indicating the need for maintenance.
Pressure_Temp_Interaction = Pressure * Temperature
- Models the combined stress of internal pressure and temperature.
- High values could indicate overheating or internal system strain, which may lead to failure.
Operational_Stress_Index = (Vibration_Levels + Pressure + Temperature + Usage_Hours) / (Age_of_Asset + 1e-5)
- A composite score that reflects the overall operational stress experienced by an asset.
- Normalized by asset age to provide a measure of how much load the asset is handling relative to its lifecycle.
- Higher values may suggest overuse or abnormal stress, potentially increasing failure risk.
These engineered features introduce meaningful interactions that are likely to improve model performance, particularly in scenarios involving military asset maintenance, where wear-and-tear dynamics are complex and interdependent.
# Recompute correlations now that the engineered interaction features exist
corr_matrix = df.corr()

# Annotated heatmap including the new features
plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm',
            square=True, cbar_kws={"shrink": .8})
plt.title("Correlation Matrix of Features", fontsize=16)
plt.tight_layout()
plt.show()
Engineered features are mathematically and contextually justified, and their correlations confirm their relevance.
Since no feature has a strong linear relationship with Failure_Event, feature interactions and ensemble models are critical.
Some features (like Humidity, Location, and Asset_Type) have near-zero correlation with most variables, which may warrant further analysis or dimensionality reduction.
Data Pre Processing¶
# Features and target: everything except Failure_Event is an input feature
X = df.drop(columns=['Failure_Event'])
y = df['Failure_Event']
# Define categorical and numerical features
cat_features = ['Asset_Type', 'Location', 'Maintenance_History']
num_features = [col for col in X.columns if col not in cat_features]
# Stratified 80/20 train/validation split (preserves the class ratio)
X_train, X_val, y_train, y_val = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# --- Apply SMOTE to balance the classes in the training data ---
# SMOTENC needs positional indices of the categorical columns
categorical_indices = [X_train.columns.get_loc(col) for col in cat_features]
print("Before SMOTE:", Counter(y_train))
# Apply SMOTENC (handles mixed categorical and continuous features)
smote_nc = SMOTENC(
categorical_features=categorical_indices,
sampling_strategy=0.5, # Make minority class 50% the size of the majority class
random_state=42
)
X_train_balanced, y_train_balanced = smote_nc.fit_resample(X_train, y_train)
# Convert back to DataFrame/Series for further processing
X_train_final = pd.DataFrame(X_train_balanced, columns=X_train.columns)
y_train_final = pd.Series(y_train_balanced, name='Failure_Event')
print("After SMOTE:", Counter(y_train_balanced))
# --- Scale numeric features after SMOTE ---
# Scaler is fit on the (resampled) training data only, to avoid leakage into validation
scaler = StandardScaler()
# Separate numerical columns for scaling
X_train_num = X_train_final[num_features]
X_val_num = X_val[num_features]
X_train_num_scaled = scaler.fit_transform(X_train_num)
X_val_num_scaled = scaler.transform(X_val_num)
# Convert back to DataFrame (indices reset so the later concat aligns row-wise)
X_train_num_scaled_df = pd.DataFrame(X_train_num_scaled, columns=num_features).reset_index(drop=True)
X_val_num_scaled_df = pd.DataFrame(X_val_num_scaled, columns=num_features).reset_index(drop=True)
# Final training/validation sets: scaled numerics + unscaled categorical columns
X_train_final = pd.concat([X_train_num_scaled_df, X_train_final[cat_features].reset_index(drop=True)], axis=1)
X_val_final = pd.concat([X_val_num_scaled_df, X_val[cat_features].reset_index(drop=True)], axis=1)
Before SMOTE: Counter({0: 33949, 1: 6051})
After SMOTE: Counter({0: 33949, 1: 16974})
Feature Selection
- Target variable
Failure_Eventwas separated from the features. - Categorical features:
Asset_Type,Location,Maintenance_History. - Remaining columns treated as numerical features.
- Target variable
Train-Test Split
- Data was split 80/20 using stratified sampling to maintain class distribution.
Class Imbalance Handling
- Applied
SMOTENCto oversample the minority class (Failure_Event = 1) while handling both categorical and numerical data. - Sampling strategy increased the minority class to 50% of the majority. This will allow me to use class weights.
- Applied
Feature Scaling
- Numerical features were scaled using
StandardScalerto normalize distributions. - Categorical features were preserved without scaling.
- Numerical features were scaled using
Final Dataset Preparation
- Scaled numerical and raw categorical features were combined to form
X_train_finalandX_val_final.
- Scaled numerical and raw categorical features were combined to form
Even after applying SMOTENC, class imbalance may still persist slightly.
Using class weights in the model ensures that:
- The algorithm pays more attention to the minority class (failures), which is critical in high-risk domains like military maintenance.
- It compensates for any remaining imbalance, improving the model's sensitivity to rare but important failure events.
- It works synergistically with SMOTENC — SMOTENC balances the data distribution, while class weights adjust the loss function during training.
NOTE: I have experimented using only class weights without any SMOTENC, but the model didn't predict the minority class. I tried using SMOTENC without class weights, and the model behaved the same. Thus I have used SMOTENC with the minority class oversampled to constitute only 50% of the majority class, and used class weights where required.
Random Forest¶
We will first use grid search to find the best parameters for the Random Forest model. The model will then be trained using the k-fold cross validation.
Grid Search for best parameters¶
# Set up a Random Forest with basic pruning options.
# class_weight='balanced' reweights the loss toward the minority (failure) class.
rf = RandomForestClassifier(random_state=42, n_jobs=-1, class_weight='balanced')
# Grid search over pruning-related hyperparameters, scored by ROC-AUC (3-fold CV)
param_grid = {
'n_estimators': [100],
'max_depth': [3, 5, 7, None], # Control tree size (pruning)
'min_samples_split': [2, 5, 10], # Prevent overgrowth
'min_samples_leaf': [1, 2, 4],
'max_features': ['sqrt', 'log2']
}
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=3, scoring='roc_auc', verbose=1)
grid_search.fit(X_train_balanced, y_train_balanced) # unscaled data: Random Forests are scale-invariant
# Get the best model and evaluate it on the held-out validation set
best_rf_F = grid_search.best_estimator_
y_pred = best_rf_F.predict(X_val)
y_prob = best_rf_F.predict_proba(X_val)[:, 1]
print("Best Params:", grid_search.best_params_)
print("AUC-ROC:", roc_auc_score(y_val, y_prob))
print(classification_report(y_val, y_pred, digits=4))
mcc = matthews_corrcoef(y_val, y_pred)
print(f"MCC: {mcc:.4f}")
# Visualize the first tree in the forest (depth capped for readability)
plt.figure(figsize=(30, 20))
plot_tree(best_rf_F.estimators_[0],
feature_names=X.columns,
class_names=['Class 0', 'Class 1'],
filled=True,
rounded=True,
max_depth=3,
fontsize=10) # Only show top 3 levels for clarity
plt.title("Random Forest - Tree Visualization")
plt.show()
# Persist the tuned model for later inference
joblib.dump(best_rf_F, 'Failure_event_random_forest_model.joblib')
Fitting 3 folds for each of 72 candidates, totalling 216 fits
Best Params: {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}
AUC-ROC: 0.5029605171191802
precision recall f1-score support
0 0.8495 0.9390 0.8920 8487
1 0.1632 0.0668 0.0947 1513
accuracy 0.8070 10000
macro avg 0.5063 0.5029 0.4934 10000
weighted avg 0.7456 0.8070 0.7714 10000
MCC: 0.0085
['Failure_event_random_forest_model.joblib']
Random Forest Classifier with Pruning and Hyperparameter Tuning¶
Model Setup¶
A RandomForestClassifier is initialized with:
class_weight='balanced'to handle class imbalance by penalizing the majority class during training.random_state=42ensures reproducibility.n_jobs=-1enables parallel computation for faster training.
Hyperparameter Tuning using GridSearchCV¶
A grid search is performed over key pruning-related hyperparameters:
n_estimators: Number of trees in the forest (fixed at 100).max_depth: Maximum depth of each tree to prevent overfitting.min_samples_split: Minimum samples required to split a node (controls granularity).min_samples_leaf: Minimum samples at a leaf node (avoids overly specific rules).max_features: Number of features to consider when splitting a node (sqrt,log2help reduce variance).
Scoring Metric: roc_auc
Cross-validation: 3-fold (cv=3)
Note: X_train_balanced is used without scaling, as Random Forests are scale-invariant.
Model Evaluation¶
The best model from GridSearchCV is used to predict on the validation set:
AUC-ROCis computed to measure overall classification performance.- A full
classification_reportis printed (precision, recall, F1-score). Matthews Correlation Coefficient (MCC)is calculated to provide a balanced metric even with class imbalance.
Visualizing One Decision Tree¶
A single tree (first in the ensemble) is visualized with:
max_depth=3for clarity.- Feature names and class labels included.
- Colors and shapes enhance interpretability.
Random Forest is an ensemble, but visualizing one tree helps understand individual decision paths.
Model Saving¶
The trained best model is saved using joblib for future inference:
joblib.dump(best_rf_F, 'Failure_event_random_forest_model.joblib')
Random Forest Tree Interpretation from Random Forest (Top 3 Levels)
Key Features Used for Splitting
Vibration_Levels (Root Node)
- Most important feature at the root.
- Lower vibration values are associated with higher likelihood of failure (Class 1).
Age_Vibration_Interaction
- Combines asset age and vibration to capture compound degradation.
- Lower interaction values (younger assets with some vibration) lean toward Class 0.
Pressure and Operational_Stress_Index
- Further refine splits based on operational intensity.
- High stress or pressure correlates with different failure risks.
Fuel_Consumption
- Used repeatedly at different depths.
- Low consumption is often associated with failed assets.
Thermal_Stress and Pressure_Temp_Interaction
- Capture specific mechanical strain or operational load effects.
- Moderate impact on classification in sub-branches.
General Observations
- The model is using engineered features effectively (e.g.,
Age_Vibration_Interaction,Operational_Stress_Index,Thermal_Stress), showing their importance. - Failure prediction (Class 1) is influenced by combinations of operational intensity and asset condition.
- Decision splits generally reflect intuitive domain logic: high stress, abnormal vibration, or low fuel efficiency increases failure risk.
Decision Tree Depth and Clarity
- Only the top 3 levels are visualized to maintain clarity.
- Beyond this, additional splits continue refining classifications using similar or supporting features.
This tree is a single estimator from the Random Forest — useful for interpretability, but the final prediction is made by aggregating across all trees.
K-fold on Best Rf model¶
# Evaluate the tuned Random Forest with stratified 5-fold cross-validation.
# Each fold refits the best estimator, prints classification metrics, and plots
# the confusion matrix, predicted-probability histogram, and ROC curve.
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import (
    roc_auc_score,
    classification_report,
    matthews_corrcoef,
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_recall_curve,
)

# Define Stratified K-Fold (shuffled, seeded for reproducibility)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Per-fold AUC values, collected for the final summary
roc_auc_scores = []

for fold, (train_idx, val_idx) in enumerate(cv.split(X_train_balanced, y_train_balanced), start=1):
    X_train_fold, X_val_fold = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    y_train_fold, y_val_fold = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]

    # Refit the best grid-search model on this fold's training split
    best_rf_F.fit(X_train_fold, y_train_fold)
    y_pred = best_rf_F.predict(X_val_fold)
    y_proba = best_rf_F.predict_proba(X_val_fold)[:, 1]

    # AUC for this fold
    roc_auc = roc_auc_score(y_val_fold, y_proba)
    roc_auc_scores.append(roc_auc)

    # Print classification report and balanced metrics
    print(f"\nFold {fold} - Classification Report:")
    print(classification_report(y_val_fold, y_pred, digits=4))
    print(f"Fold {fold} - AUC-ROC: {roc_auc:.4f}")
    mcc = matthews_corrcoef(y_val_fold, y_pred)
    print(f"MCC: {mcc:.4f}")

    # Confusion Matrix
    cm = confusion_matrix(y_val_fold, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve points. Bug fix: the original also computed an unused
    # `roc_auc_curve = auc(fpr, tpr)` duplicate of `roc_auc`; it is removed.
    fpr, tpr, thresholds = roc_curve(y_val_fold, y_proba)

    # Histogram of predicted probabilities split by the true class
    plt.figure(figsize=(6, 4))
    plt.hist(y_proba[y_val_fold == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
    plt.hist(y_proba[y_val_fold == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Plot ROC Curve
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC AUC = {roc_auc:.4f}")
    plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

# Print average AUC over all folds
print("\nAverage AUC-ROC across folds:")
print(f"{np.mean(roc_auc_scores):.4f} ± {np.std(roc_auc_scores):.4f}")
Fold 1 - Classification Report:
precision recall f1-score support
0 0.7487 0.9405 0.8337 6790
1 0.7560 0.3688 0.4957 3395
accuracy 0.7499 10185
macro avg 0.7524 0.6546 0.6647 10185
weighted avg 0.7512 0.7499 0.7211 10185
Fold 1 - AUC-ROC: 0.7769
MCC: 0.3951
Fold 2 - Classification Report:
precision recall f1-score support
0 0.7461 0.9471 0.8347 6790
1 0.7706 0.3552 0.4863 3395
accuracy 0.7498 10185
macro avg 0.7583 0.6512 0.6605 10185
weighted avg 0.7542 0.7498 0.7185 10185
Fold 2 - AUC-ROC: 0.7795
MCC: 0.3952
Fold 3 - Classification Report:
precision recall f1-score support
0 0.7450 0.9367 0.8299 6790
1 0.7391 0.3588 0.4830 3395
accuracy 0.7440 10185
macro avg 0.7420 0.6477 0.6565 10185
weighted avg 0.7430 0.7440 0.7143 10185
Fold 3 - AUC-ROC: 0.7722
MCC: 0.3782
Fold 4 - Classification Report:
precision recall f1-score support
0 0.7445 0.9412 0.8314 6790
1 0.7506 0.3539 0.4810 3394
accuracy 0.7455 10184
macro avg 0.7476 0.6475 0.6562 10184
weighted avg 0.7466 0.7455 0.7146 10184
Fold 4 - AUC-ROC: 0.7706
MCC: 0.3823
Fold 5 - Classification Report:
precision recall f1-score support
0 0.7460 0.9378 0.8310 6789
1 0.7441 0.3614 0.4865 3395
accuracy 0.7457 10184
macro avg 0.7450 0.6496 0.6588 10184
weighted avg 0.7454 0.7457 0.7162 10184
Fold 5 - AUC-ROC: 0.7721
MCC: 0.3830
Average AUC-ROC across folds: 0.7742 ± 0.0034
TAB NET¶
# Configuration
TARGET_COL = 'Failure_Event'
CATEGORICAL_COLS = ['Asset_Type', 'Location', 'Maintenance_History']
RANDOM_STATE = 42
N_SPLITS = 5
EPOCHS = 200
BATCH_SIZE = 64
# Identify categorical feature indices and dimensions
cat_idxs = [X_train_balanced.columns.get_loc(col) for col in CATEGORICAL_COLS]
cat_dims = [int(df[col].nunique()) for col in CATEGORICAL_COLS]
cat_emb_dim = [min(50, (dim + 1) // 2) for dim in cat_dims]
# Cross-validation setup
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
fold = 1
auc_scores = []
classes = np.unique(y_train_balanced)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train_balanced)
weights = dict(zip(classes, class_weights))
for train_idx, val_idx in skf.split(X_train_balanced, y_train_balanced):
    print(f"\n==== Fold {fold} ====")
    # Per-fold split of the balanced training data.
    TAB_X_train, TAB_X_val = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    TAB_y_train, TAB_y_val = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]

    # Initialize and train a fresh model for this fold.
    clf_F = TabNetClassifier(
        n_d=32, n_a=32, n_steps=5, gamma=1.5,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=cat_emb_dim,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=1e-2),
        scheduler_params={"step_size": 10, "gamma": 0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='entmax',
        seed=RANDOM_STATE,
        verbose=1
    )
    clf_F.fit(
        X_train=TAB_X_train.values, y_train=TAB_y_train.values,
        eval_set=[(TAB_X_val.values, TAB_y_val.values)],
        eval_name=['val'],
        eval_metric=['auc'],
        max_epochs=EPOCHS,
        patience=20,
        batch_size=BATCH_SIZE,
        virtual_batch_size=128,
        weights=weights
    )

    # Evaluation on the held-out fold.
    y_pred_proba = clf_F.predict_proba(TAB_X_val.values)[:, 1]
    y_pred = clf_F.predict(TAB_X_val.values)
    # Named fold_auc (not `auc`) so the sklearn.metrics.auc *function*
    # imported at the top of the file is not shadowed by a float — the
    # original needed a re-import inside the loop to undo that shadowing.
    fold_auc = roc_auc_score(TAB_y_val.values, y_pred_proba)
    print(classification_report(TAB_y_val.values, y_pred, digits=4))
    print(f"Fold {fold} AUC: {fold_auc:.4f}")
    auc_scores.append(fold_auc)
    mcc = matthews_corrcoef(TAB_y_val, y_pred)
    print(f"MCC: {mcc:.4f}")

    # Confusion matrix
    cm = confusion_matrix(TAB_y_val, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve. roc_curve is already imported at the top of the file;
    # no per-iteration import needed. roc_auc_score equals auc(fpr, tpr)
    # for binary labels, so reuse fold_auc instead of recomputing it.
    fpr, tpr, thresholds = roc_curve(TAB_y_val, y_pred_proba)
    roc_auc = fold_auc
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Predicted-probability distributions per true class: visual check of
    # how well the model separates failures from non-failures.
    plt.figure(figsize=(6, 4))
    plt.hist(y_pred_proba[TAB_y_val == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
    plt.hist(y_pred_proba[TAB_y_val == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()
    fold += 1
# Summarize cross-validation performance and persist the last fold's model.
mean_auc, std_auc = np.mean(auc_scores), np.std(auc_scores)
print("\n==== Cross-Validation Complete ====")
print(f"Mean AUC: {mean_auc:.4f} | Std AUC: {std_auc:.4f}")
clf_F.save_model("Failure_Event_tabnet_model")
==== Fold 1 ====
epoch 0 | loss: 0.71458 | val_auc: 0.58514 | 0:00:08s
epoch 1 | loss: 0.68575 | val_auc: 0.59763 | 0:00:15s
epoch 2 | loss: 0.68108 | val_auc: 0.61216 | 0:00:23s
epoch 3 | loss: 0.67893 | val_auc: 0.61338 | 0:00:31s
epoch 4 | loss: 0.67825 | val_auc: 0.61639 | 0:00:39s
epoch 5 | loss: 0.6744 | val_auc: 0.61304 | 0:00:47s
epoch 6 | loss: 0.67283 | val_auc: 0.6182 | 0:00:55s
epoch 7 | loss: 0.6697 | val_auc: 0.6287 | 0:01:04s
epoch 8 | loss: 0.67066 | val_auc: 0.63418 | 0:01:12s
epoch 9 | loss: 0.66839 | val_auc: 0.63039 | 0:01:21s
epoch 10 | loss: 0.67151 | val_auc: 0.62486 | 0:01:29s
epoch 11 | loss: 0.67072 | val_auc: 0.61996 | 0:01:38s
epoch 12 | loss: 0.67 | val_auc: 0.63585 | 0:01:49s
epoch 13 | loss: 0.66961 | val_auc: 0.62514 | 0:01:58s
epoch 14 | loss: 0.66823 | val_auc: 0.63112 | 0:02:07s
epoch 15 | loss: 0.66733 | val_auc: 0.63498 | 0:02:15s
epoch 16 | loss: 0.66758 | val_auc: 0.6419 | 0:02:24s
epoch 17 | loss: 0.66449 | val_auc: 0.64695 | 0:02:34s
epoch 18 | loss: 0.66605 | val_auc: 0.64354 | 0:02:42s
epoch 19 | loss: 0.6628 | val_auc: 0.64159 | 0:02:51s
epoch 20 | loss: 0.66557 | val_auc: 0.64505 | 0:03:00s
epoch 21 | loss: 0.66332 | val_auc: 0.64416 | 0:03:09s
epoch 22 | loss: 0.66111 | val_auc: 0.64345 | 0:03:19s
epoch 23 | loss: 0.66118 | val_auc: 0.64386 | 0:03:29s
epoch 24 | loss: 0.66016 | val_auc: 0.65331 | 0:03:39s
epoch 25 | loss: 0.65964 | val_auc: 0.65088 | 0:03:47s
epoch 26 | loss: 0.65719 | val_auc: 0.65574 | 0:03:56s
epoch 27 | loss: 0.6491 | val_auc: 0.671 | 0:04:04s
epoch 28 | loss: 0.63201 | val_auc: 0.66614 | 0:04:13s
epoch 29 | loss: 0.61198 | val_auc: 0.69575 | 0:04:22s
epoch 30 | loss: 0.58723 | val_auc: 0.70743 | 0:04:30s
epoch 31 | loss: 0.56325 | val_auc: 0.74714 | 0:04:40s
epoch 32 | loss: 0.54804 | val_auc: 0.7355 | 0:04:48s
epoch 33 | loss: 0.54409 | val_auc: 0.755 | 0:04:57s
epoch 34 | loss: 0.53694 | val_auc: 0.71618 | 0:05:06s
epoch 35 | loss: 0.52921 | val_auc: 0.73014 | 0:05:16s
epoch 36 | loss: 0.5318 | val_auc: 0.70244 | 0:05:25s
epoch 37 | loss: 0.52623 | val_auc: 0.71791 | 0:05:34s
epoch 38 | loss: 0.52309 | val_auc: 0.74466 | 0:05:45s
epoch 39 | loss: 0.52391 | val_auc: 0.73747 | 0:05:54s
epoch 40 | loss: 0.51575 | val_auc: 0.71515 | 0:06:04s
epoch 41 | loss: 0.51472 | val_auc: 0.75724 | 0:06:13s
epoch 42 | loss: 0.51097 | val_auc: 0.76026 | 0:06:23s
epoch 43 | loss: 0.51218 | val_auc: 0.71206 | 0:06:32s
epoch 44 | loss: 0.51132 | val_auc: 0.63665 | 0:06:41s
epoch 45 | loss: 0.51298 | val_auc: 0.75833 | 0:06:50s
epoch 46 | loss: 0.50666 | val_auc: 0.75484 | 0:07:00s
epoch 47 | loss: 0.50242 | val_auc: 0.74443 | 0:07:10s
epoch 48 | loss: 0.50202 | val_auc: 0.67663 | 0:07:20s
epoch 49 | loss: 0.50026 | val_auc: 0.72603 | 0:07:29s
epoch 50 | loss: 0.49844 | val_auc: 0.70896 | 0:07:39s
epoch 51 | loss: 0.49932 | val_auc: 0.69035 | 0:07:48s
epoch 52 | loss: 0.49436 | val_auc: 0.75091 | 0:07:57s
epoch 53 | loss: 0.4942 | val_auc: 0.73055 | 0:08:05s
epoch 54 | loss: 0.49471 | val_auc: 0.76354 | 0:08:14s
epoch 55 | loss: 0.49058 | val_auc: 0.66817 | 0:08:22s
epoch 56 | loss: 0.49392 | val_auc: 0.72228 | 0:08:31s
epoch 57 | loss: 0.49568 | val_auc: 0.74181 | 0:08:39s
epoch 58 | loss: 0.4904 | val_auc: 0.6678 | 0:08:48s
epoch 59 | loss: 0.49121 | val_auc: 0.73797 | 0:08:56s
epoch 60 | loss: 0.48977 | val_auc: 0.68416 | 0:09:05s
epoch 61 | loss: 0.48328 | val_auc: 0.76191 | 0:09:14s
epoch 62 | loss: 0.48807 | val_auc: 0.70157 | 0:09:22s
epoch 63 | loss: 0.48529 | val_auc: 0.69813 | 0:09:31s
epoch 64 | loss: 0.48864 | val_auc: 0.77281 | 0:09:40s
epoch 65 | loss: 0.4869 | val_auc: 0.73468 | 0:09:48s
epoch 66 | loss: 0.4867 | val_auc: 0.66701 | 0:09:56s
epoch 67 | loss: 0.48994 | val_auc: 0.73992 | 0:10:08s
epoch 68 | loss: 0.48501 | val_auc: 0.74521 | 0:10:28s
epoch 69 | loss: 0.48666 | val_auc: 0.74401 | 0:10:41s
epoch 70 | loss: 0.48619 | val_auc: 0.72978 | 0:10:53s
epoch 71 | loss: 0.48218 | val_auc: 0.74547 | 0:11:06s
epoch 72 | loss: 0.48358 | val_auc: 0.76971 | 0:11:17s
epoch 73 | loss: 0.47678 | val_auc: 0.75136 | 0:11:28s
epoch 74 | loss: 0.48124 | val_auc: 0.7399 | 0:11:39s
epoch 75 | loss: 0.48078 | val_auc: 0.77214 | 0:11:49s
epoch 76 | loss: 0.47987 | val_auc: 0.70952 | 0:11:58s
epoch 77 | loss: 0.4761 | val_auc: 0.73663 | 0:12:07s
epoch 78 | loss: 0.4817 | val_auc: 0.74254 | 0:12:16s
epoch 79 | loss: 0.47893 | val_auc: 0.74586 | 0:12:25s
epoch 80 | loss: 0.47633 | val_auc: 0.74668 | 0:12:35s
epoch 81 | loss: 0.47451 | val_auc: 0.75681 | 0:12:43s
epoch 82 | loss: 0.47598 | val_auc: 0.72161 | 0:12:52s
epoch 83 | loss: 0.47235 | val_auc: 0.73272 | 0:13:00s
epoch 84 | loss: 0.47724 | val_auc: 0.74193 | 0:13:09s
Early stopping occurred at epoch 84 with best_epoch = 64 and best_val_auc = 0.77281
precision recall f1-score support
0 0.7991 0.9710 0.8767 6790
1 0.8981 0.5116 0.6519 3395
accuracy 0.8179 10185
macro avg 0.8486 0.7413 0.7643 10185
weighted avg 0.8321 0.8179 0.8017 10185
Fold 1 AUC: 0.7728
MCC: 0.5801
==== Fold 2 ====
epoch 0 | loss: 0.7257 | val_auc: 0.53746 | 0:00:08s
epoch 1 | loss: 0.68761 | val_auc: 0.58829 | 0:00:16s
epoch 2 | loss: 0.68516 | val_auc: 0.59725 | 0:00:24s
epoch 3 | loss: 0.68323 | val_auc: 0.59705 | 0:00:32s
epoch 4 | loss: 0.67999 | val_auc: 0.60292 | 0:00:40s
epoch 5 | loss: 0.68029 | val_auc: 0.61052 | 0:00:48s
epoch 6 | loss: 0.67869 | val_auc: 0.6098 | 0:00:56s
epoch 7 | loss: 0.67758 | val_auc: 0.60534 | 0:01:03s
epoch 8 | loss: 0.6794 | val_auc: 0.60925 | 0:01:11s
epoch 9 | loss: 0.67816 | val_auc: 0.60583 | 0:01:19s
epoch 10 | loss: 0.67685 | val_auc: 0.60395 | 0:01:27s
epoch 11 | loss: 0.67939 | val_auc: 0.60684 | 0:01:34s
epoch 12 | loss: 0.68099 | val_auc: 0.59842 | 0:01:42s
epoch 13 | loss: 0.67846 | val_auc: 0.60648 | 0:01:49s
epoch 14 | loss: 0.67964 | val_auc: 0.60525 | 0:01:57s
epoch 15 | loss: 0.67872 | val_auc: 0.61034 | 0:02:04s
epoch 16 | loss: 0.68044 | val_auc: 0.60728 | 0:02:12s
epoch 17 | loss: 0.67801 | val_auc: 0.60308 | 0:02:20s
epoch 18 | loss: 0.67865 | val_auc: 0.60383 | 0:02:27s
epoch 19 | loss: 0.67857 | val_auc: 0.60348 | 0:02:35s
epoch 20 | loss: 0.67905 | val_auc: 0.60752 | 0:02:42s
epoch 21 | loss: 0.67902 | val_auc: 0.60377 | 0:02:50s
epoch 22 | loss: 0.67894 | val_auc: 0.60094 | 0:02:57s
epoch 23 | loss: 0.67621 | val_auc: 0.60858 | 0:03:05s
epoch 24 | loss: 0.67694 | val_auc: 0.61065 | 0:03:12s
epoch 25 | loss: 0.67631 | val_auc: 0.61026 | 0:03:20s
epoch 26 | loss: 0.67678 | val_auc: 0.61128 | 0:03:28s
epoch 27 | loss: 0.67715 | val_auc: 0.61047 | 0:03:35s
epoch 28 | loss: 0.67667 | val_auc: 0.60139 | 0:03:43s
epoch 29 | loss: 0.67677 | val_auc: 0.61035 | 0:03:51s
epoch 30 | loss: 0.67653 | val_auc: 0.61154 | 0:03:58s
epoch 31 | loss: 0.67662 | val_auc: 0.61235 | 0:04:06s
epoch 32 | loss: 0.67587 | val_auc: 0.61207 | 0:04:14s
epoch 33 | loss: 0.67723 | val_auc: 0.61022 | 0:04:21s
epoch 34 | loss: 0.67644 | val_auc: 0.61357 | 0:04:29s
epoch 35 | loss: 0.67582 | val_auc: 0.6133 | 0:04:37s
epoch 36 | loss: 0.67542 | val_auc: 0.61454 | 0:04:45s
epoch 37 | loss: 0.6765 | val_auc: 0.61416 | 0:04:52s
epoch 38 | loss: 0.67528 | val_auc: 0.61201 | 0:05:00s
epoch 39 | loss: 0.67663 | val_auc: 0.61751 | 0:05:07s
epoch 40 | loss: 0.67595 | val_auc: 0.61399 | 0:05:15s
epoch 41 | loss: 0.6749 | val_auc: 0.61576 | 0:05:23s
epoch 42 | loss: 0.67498 | val_auc: 0.61713 | 0:05:31s
epoch 43 | loss: 0.67593 | val_auc: 0.6126 | 0:05:40s
epoch 44 | loss: 0.67587 | val_auc: 0.61462 | 0:05:48s
epoch 45 | loss: 0.67589 | val_auc: 0.61386 | 0:05:56s
epoch 46 | loss: 0.67656 | val_auc: 0.61261 | 0:06:03s
epoch 47 | loss: 0.67645 | val_auc: 0.61577 | 0:06:11s
epoch 48 | loss: 0.67534 | val_auc: 0.61755 | 0:06:19s
epoch 49 | loss: 0.67701 | val_auc: 0.61101 | 0:06:28s
epoch 50 | loss: 0.67597 | val_auc: 0.61447 | 0:06:37s
epoch 51 | loss: 0.67624 | val_auc: 0.61677 | 0:06:46s
epoch 52 | loss: 0.67408 | val_auc: 0.61248 | 0:06:54s
epoch 53 | loss: 0.67487 | val_auc: 0.615 | 0:07:03s
epoch 54 | loss: 0.67807 | val_auc: 0.61561 | 0:07:11s
epoch 55 | loss: 0.67628 | val_auc: 0.61546 | 0:07:19s
epoch 56 | loss: 0.67576 | val_auc: 0.61431 | 0:07:27s
epoch 57 | loss: 0.67523 | val_auc: 0.61536 | 0:07:35s
epoch 58 | loss: 0.67759 | val_auc: 0.61397 | 0:07:43s
epoch 59 | loss: 0.67598 | val_auc: 0.61615 | 0:07:51s
epoch 60 | loss: 0.67776 | val_auc: 0.61799 | 0:07:59s
epoch 61 | loss: 0.67584 | val_auc: 0.61838 | 0:08:07s
epoch 62 | loss: 0.67573 | val_auc: 0.61738 | 0:08:15s
epoch 63 | loss: 0.67548 | val_auc: 0.62134 | 0:08:23s
epoch 64 | loss: 0.67541 | val_auc: 0.62044 | 0:08:31s
epoch 65 | loss: 0.67432 | val_auc: 0.62006 | 0:08:38s
epoch 66 | loss: 0.67386 | val_auc: 0.61963 | 0:08:46s
epoch 67 | loss: 0.67544 | val_auc: 0.6218 | 0:08:54s
epoch 68 | loss: 0.67277 | val_auc: 0.61955 | 0:09:02s
epoch 69 | loss: 0.6743 | val_auc: 0.62118 | 0:09:11s
epoch 70 | loss: 0.67474 | val_auc: 0.62311 | 0:09:19s
epoch 71 | loss: 0.67403 | val_auc: 0.62257 | 0:09:27s
epoch 72 | loss: 0.67199 | val_auc: 0.62108 | 0:09:35s
epoch 73 | loss: 0.67398 | val_auc: 0.6246 | 0:09:44s
epoch 74 | loss: 0.67331 | val_auc: 0.62247 | 0:09:52s
epoch 75 | loss: 0.67452 | val_auc: 0.62421 | 0:10:00s
epoch 76 | loss: 0.67449 | val_auc: 0.62342 | 0:10:09s
epoch 77 | loss: 0.67429 | val_auc: 0.61967 | 0:10:17s
epoch 78 | loss: 0.67434 | val_auc: 0.625 | 0:10:26s
epoch 79 | loss: 0.6737 | val_auc: 0.62474 | 0:10:34s
epoch 80 | loss: 0.67297 | val_auc: 0.62264 | 0:10:45s
epoch 81 | loss: 0.67493 | val_auc: 0.6265 | 0:10:55s
epoch 82 | loss: 0.67166 | val_auc: 0.62601 | 0:11:04s
epoch 83 | loss: 0.67239 | val_auc: 0.62398 | 0:11:14s
epoch 84 | loss: 0.67228 | val_auc: 0.62457 | 0:11:23s
epoch 85 | loss: 0.67285 | val_auc: 0.62422 | 0:11:32s
epoch 86 | loss: 0.67159 | val_auc: 0.62474 | 0:11:41s
epoch 87 | loss: 0.67553 | val_auc: 0.62966 | 0:11:51s
epoch 88 | loss: 0.67158 | val_auc: 0.62621 | 0:11:59s
epoch 89 | loss: 0.67214 | val_auc: 0.62551 | 0:12:08s
epoch 90 | loss: 0.67276 | val_auc: 0.62614 | 0:12:17s
epoch 91 | loss: 0.67125 | val_auc: 0.62688 | 0:12:28s
epoch 92 | loss: 0.67053 | val_auc: 0.62729 | 0:12:38s
epoch 93 | loss: 0.67084 | val_auc: 0.62689 | 0:12:47s
epoch 94 | loss: 0.67016 | val_auc: 0.62744 | 0:12:56s
epoch 95 | loss: 0.6711 | val_auc: 0.62668 | 0:13:07s
epoch 96 | loss: 0.67111 | val_auc: 0.62823 | 0:13:20s
epoch 97 | loss: 0.67122 | val_auc: 0.62452 | 0:13:30s
epoch 98 | loss: 0.67187 | val_auc: 0.62327 | 0:13:40s
epoch 99 | loss: 0.67294 | val_auc: 0.63076 | 0:13:49s
epoch 100| loss: 0.67022 | val_auc: 0.63146 | 0:13:59s
epoch 101| loss: 0.67218 | val_auc: 0.63074 | 0:14:08s
epoch 102| loss: 0.67101 | val_auc: 0.62925 | 0:14:18s
epoch 103| loss: 0.67286 | val_auc: 0.63147 | 0:14:29s
epoch 104| loss: 0.67242 | val_auc: 0.63113 | 0:14:39s
epoch 105| loss: 0.67051 | val_auc: 0.6318 | 0:14:48s
epoch 106| loss: 0.67121 | val_auc: 0.63181 | 0:14:57s
epoch 107| loss: 0.67174 | val_auc: 0.63271 | 0:15:07s
epoch 108| loss: 0.67015 | val_auc: 0.63342 | 0:15:16s
epoch 109| loss: 0.67167 | val_auc: 0.63363 | 0:15:25s
epoch 110| loss: 0.67037 | val_auc: 0.63235 | 0:15:35s
epoch 111| loss: 0.66924 | val_auc: 0.63267 | 0:15:44s
epoch 112| loss: 0.66941 | val_auc: 0.63205 | 0:15:53s
epoch 113| loss: 0.67201 | val_auc: 0.63283 | 0:16:03s
epoch 114| loss: 0.67087 | val_auc: 0.63305 | 0:16:13s
epoch 115| loss: 0.6702 | val_auc: 0.6352 | 0:16:23s
epoch 116| loss: 0.67071 | val_auc: 0.63337 | 0:16:33s
epoch 117| loss: 0.67049 | val_auc: 0.6306 | 0:16:42s
epoch 118| loss: 0.66964 | val_auc: 0.63496 | 0:16:52s
epoch 119| loss: 0.66975 | val_auc: 0.63274 | 0:17:02s
epoch 120| loss: 0.67164 | val_auc: 0.63474 | 0:17:14s
epoch 121| loss: 0.67245 | val_auc: 0.63343 | 0:17:23s
epoch 122| loss: 0.67056 | val_auc: 0.63268 | 0:17:34s
epoch 123| loss: 0.67062 | val_auc: 0.63371 | 0:17:45s
epoch 124| loss: 0.67194 | val_auc: 0.63116 | 0:17:56s
epoch 125| loss: 0.66845 | val_auc: 0.633 | 0:18:07s
epoch 126| loss: 0.66861 | val_auc: 0.63384 | 0:18:17s
epoch 127| loss: 0.67012 | val_auc: 0.63194 | 0:18:26s
epoch 128| loss: 0.67204 | val_auc: 0.62821 | 0:18:36s
epoch 129| loss: 0.67145 | val_auc: 0.63158 | 0:18:45s
epoch 130| loss: 0.67129 | val_auc: 0.63305 | 0:18:55s
epoch 131| loss: 0.67133 | val_auc: 0.6317 | 0:19:04s
epoch 132| loss: 0.6705 | val_auc: 0.63297 | 0:19:13s
epoch 133| loss: 0.67003 | val_auc: 0.63473 | 0:19:23s
epoch 134| loss: 0.67035 | val_auc: 0.63209 | 0:19:32s
epoch 135| loss: 0.67083 | val_auc: 0.63066 | 0:19:42s
Early stopping occurred at epoch 135 with best_epoch = 115 and best_val_auc = 0.6352
precision recall f1-score support
0 0.7645 0.4866 0.5947 6790
1 0.4054 0.7001 0.5135 3395
accuracy 0.5578 10185
macro avg 0.5849 0.5934 0.5541 10185
weighted avg 0.6448 0.5578 0.5676 10185
Fold 2 AUC: 0.6352
MCC: 0.1781
==== Fold 3 ====
epoch 0 | loss: 0.7218 | val_auc: 0.55885 | 0:00:09s
epoch 1 | loss: 0.6882 | val_auc: 0.59358 | 0:00:19s
epoch 2 | loss: 0.68031 | val_auc: 0.59922 | 0:00:30s
epoch 3 | loss: 0.67972 | val_auc: 0.60477 | 0:00:40s
epoch 4 | loss: 0.67724 | val_auc: 0.61739 | 0:00:49s
epoch 5 | loss: 0.67683 | val_auc: 0.60908 | 0:00:58s
epoch 6 | loss: 0.67877 | val_auc: 0.59999 | 0:01:08s
epoch 7 | loss: 0.67584 | val_auc: 0.60837 | 0:01:17s
epoch 8 | loss: 0.67571 | val_auc: 0.60798 | 0:01:26s
epoch 9 | loss: 0.67433 | val_auc: 0.60073 | 0:01:36s
epoch 10 | loss: 0.67381 | val_auc: 0.61809 | 0:01:45s
epoch 11 | loss: 0.67232 | val_auc: 0.61666 | 0:01:55s
epoch 12 | loss: 0.67153 | val_auc: 0.61992 | 0:02:04s
epoch 13 | loss: 0.6725 | val_auc: 0.61924 | 0:02:13s
epoch 14 | loss: 0.67408 | val_auc: 0.61111 | 0:02:23s
epoch 15 | loss: 0.67444 | val_auc: 0.61399 | 0:02:33s
epoch 16 | loss: 0.67469 | val_auc: 0.61232 | 0:02:42s
epoch 17 | loss: 0.67318 | val_auc: 0.61808 | 0:02:52s
epoch 18 | loss: 0.67378 | val_auc: 0.62536 | 0:03:01s
epoch 19 | loss: 0.67363 | val_auc: 0.61995 | 0:03:11s
epoch 20 | loss: 0.6731 | val_auc: 0.62426 | 0:03:20s
epoch 21 | loss: 0.67163 | val_auc: 0.62367 | 0:03:30s
epoch 22 | loss: 0.6694 | val_auc: 0.63273 | 0:03:39s
epoch 23 | loss: 0.66939 | val_auc: 0.63642 | 0:03:48s
epoch 24 | loss: 0.66927 | val_auc: 0.63103 | 0:03:58s
epoch 25 | loss: 0.66685 | val_auc: 0.6307 | 0:04:07s
epoch 26 | loss: 0.66971 | val_auc: 0.62782 | 0:04:19s
epoch 27 | loss: 0.66645 | val_auc: 0.63065 | 0:04:30s
epoch 28 | loss: 0.66788 | val_auc: 0.6314 | 0:04:40s
epoch 29 | loss: 0.66949 | val_auc: 0.63063 | 0:04:50s
epoch 30 | loss: 0.66758 | val_auc: 0.63321 | 0:04:59s
epoch 31 | loss: 0.6698 | val_auc: 0.63539 | 0:05:09s
epoch 32 | loss: 0.66682 | val_auc: 0.63332 | 0:05:18s
epoch 33 | loss: 0.66664 | val_auc: 0.63623 | 0:05:28s
epoch 34 | loss: 0.66574 | val_auc: 0.64074 | 0:05:37s
epoch 35 | loss: 0.66666 | val_auc: 0.64068 | 0:05:46s
epoch 36 | loss: 0.6663 | val_auc: 0.63394 | 0:05:55s
epoch 37 | loss: 0.66692 | val_auc: 0.63794 | 0:06:04s
epoch 38 | loss: 0.66674 | val_auc: 0.63428 | 0:06:13s
epoch 39 | loss: 0.66658 | val_auc: 0.63457 | 0:06:22s
epoch 40 | loss: 0.66576 | val_auc: 0.63829 | 0:06:31s
epoch 41 | loss: 0.66497 | val_auc: 0.64095 | 0:06:39s
epoch 42 | loss: 0.66518 | val_auc: 0.64047 | 0:06:48s
epoch 43 | loss: 0.66518 | val_auc: 0.63705 | 0:06:56s
epoch 44 | loss: 0.66577 | val_auc: 0.63856 | 0:07:04s
epoch 45 | loss: 0.66607 | val_auc: 0.64073 | 0:07:14s
epoch 46 | loss: 0.66486 | val_auc: 0.638 | 0:07:22s
epoch 47 | loss: 0.66572 | val_auc: 0.64164 | 0:07:30s
epoch 48 | loss: 0.66593 | val_auc: 0.64362 | 0:07:38s
epoch 49 | loss: 0.66622 | val_auc: 0.63904 | 0:07:46s
epoch 50 | loss: 0.66453 | val_auc: 0.63709 | 0:07:54s
epoch 51 | loss: 0.66724 | val_auc: 0.64113 | 0:08:02s
epoch 52 | loss: 0.6647 | val_auc: 0.64107 | 0:08:09s
epoch 53 | loss: 0.66428 | val_auc: 0.63805 | 0:08:17s
epoch 54 | loss: 0.66469 | val_auc: 0.63981 | 0:08:26s
epoch 55 | loss: 0.66551 | val_auc: 0.6389 | 0:08:34s
epoch 56 | loss: 0.66675 | val_auc: 0.64354 | 0:08:42s
epoch 57 | loss: 0.66501 | val_auc: 0.63528 | 0:08:49s
epoch 58 | loss: 0.66621 | val_auc: 0.63981 | 0:08:57s
epoch 59 | loss: 0.66515 | val_auc: 0.64207 | 0:09:05s
epoch 60 | loss: 0.66393 | val_auc: 0.64397 | 0:09:13s
epoch 61 | loss: 0.66424 | val_auc: 0.64084 | 0:09:20s
epoch 62 | loss: 0.66511 | val_auc: 0.64313 | 0:09:28s
epoch 63 | loss: 0.66572 | val_auc: 0.64024 | 0:09:36s
epoch 64 | loss: 0.66487 | val_auc: 0.64332 | 0:09:44s
epoch 65 | loss: 0.66401 | val_auc: 0.64226 | 0:09:51s
epoch 66 | loss: 0.66402 | val_auc: 0.63947 | 0:09:59s
epoch 67 | loss: 0.66375 | val_auc: 0.64397 | 0:10:07s
epoch 68 | loss: 0.6645 | val_auc: 0.64077 | 0:10:15s
epoch 69 | loss: 0.66196 | val_auc: 0.64171 | 0:10:22s
epoch 70 | loss: 0.66386 | val_auc: 0.64178 | 0:10:31s
epoch 71 | loss: 0.66337 | val_auc: 0.64024 | 0:10:39s
epoch 72 | loss: 0.66184 | val_auc: 0.63714 | 0:10:47s
epoch 73 | loss: 0.66477 | val_auc: 0.63992 | 0:10:55s
epoch 74 | loss: 0.66346 | val_auc: 0.64097 | 0:11:03s
epoch 75 | loss: 0.66362 | val_auc: 0.63924 | 0:11:11s
epoch 76 | loss: 0.66526 | val_auc: 0.63897 | 0:11:19s
epoch 77 | loss: 0.66341 | val_auc: 0.64222 | 0:11:27s
epoch 78 | loss: 0.66433 | val_auc: 0.64289 | 0:11:35s
epoch 79 | loss: 0.6638 | val_auc: 0.64276 | 0:11:43s
epoch 80 | loss: 0.66278 | val_auc: 0.64011 | 0:11:50s
Early stopping occurred at epoch 80 with best_epoch = 60 and best_val_auc = 0.64397
precision recall f1-score support
0 0.7683 0.5563 0.6453 6790
1 0.4282 0.6645 0.5208 3395
accuracy 0.5923 10185
macro avg 0.5982 0.6104 0.5830 10185
weighted avg 0.6549 0.5923 0.6038 10185
Fold 3 AUC: 0.6440
MCC: 0.2083
==== Fold 4 ====
epoch 0 | loss: 0.71392 | val_auc: 0.59768 | 0:00:08s
epoch 1 | loss: 0.68596 | val_auc: 0.59188 | 0:00:16s
epoch 2 | loss: 0.68274 | val_auc: 0.60029 | 0:00:26s
epoch 3 | loss: 0.67813 | val_auc: 0.60294 | 0:00:36s
epoch 4 | loss: 0.67905 | val_auc: 0.61214 | 0:00:44s
epoch 5 | loss: 0.67735 | val_auc: 0.61377 | 0:00:52s
epoch 6 | loss: 0.67779 | val_auc: 0.61148 | 0:01:00s
epoch 7 | loss: 0.67923 | val_auc: 0.60758 | 0:01:08s
epoch 8 | loss: 0.67706 | val_auc: 0.60488 | 0:01:16s
epoch 9 | loss: 0.67651 | val_auc: 0.60847 | 0:01:25s
epoch 10 | loss: 0.67599 | val_auc: 0.60629 | 0:01:33s
epoch 11 | loss: 0.67524 | val_auc: 0.6185 | 0:01:41s
epoch 12 | loss: 0.67545 | val_auc: 0.6192 | 0:01:49s
epoch 13 | loss: 0.67245 | val_auc: 0.62488 | 0:01:57s
epoch 14 | loss: 0.67228 | val_auc: 0.624 | 0:02:05s
epoch 15 | loss: 0.67096 | val_auc: 0.62234 | 0:02:13s
epoch 16 | loss: 0.67246 | val_auc: 0.62738 | 0:02:21s
epoch 17 | loss: 0.6703 | val_auc: 0.63416 | 0:02:29s
epoch 18 | loss: 0.66959 | val_auc: 0.62254 | 0:02:37s
epoch 19 | loss: 0.67049 | val_auc: 0.62396 | 0:02:45s
epoch 20 | loss: 0.66794 | val_auc: 0.63425 | 0:02:53s
epoch 21 | loss: 0.66636 | val_auc: 0.63345 | 0:03:01s
epoch 22 | loss: 0.66951 | val_auc: 0.63817 | 0:03:09s
epoch 23 | loss: 0.66621 | val_auc: 0.63856 | 0:03:17s
epoch 24 | loss: 0.66728 | val_auc: 0.6298 | 0:03:26s
epoch 25 | loss: 0.66853 | val_auc: 0.63756 | 0:03:37s
epoch 26 | loss: 0.66579 | val_auc: 0.64093 | 0:03:47s
epoch 27 | loss: 0.66558 | val_auc: 0.63392 | 0:03:55s
epoch 28 | loss: 0.66841 | val_auc: 0.62935 | 0:04:04s
epoch 29 | loss: 0.66533 | val_auc: 0.63438 | 0:04:12s
epoch 30 | loss: 0.66557 | val_auc: 0.63502 | 0:08:11s
epoch 31 | loss: 0.6642 | val_auc: 0.63649 | 0:08:19s
epoch 32 | loss: 0.66306 | val_auc: 0.63905 | 0:08:26s
epoch 33 | loss: 0.66493 | val_auc: 0.63957 | 0:21:32s
epoch 34 | loss: 0.66514 | val_auc: 0.64148 | 0:21:40s
epoch 35 | loss: 0.66407 | val_auc: 0.63754 | 0:21:46s
epoch 36 | loss: 0.66367 | val_auc: 0.64246 | 0:21:53s
epoch 37 | loss: 0.66323 | val_auc: 0.63896 | 0:21:59s
epoch 38 | loss: 0.66421 | val_auc: 0.64307 | 0:22:06s
epoch 39 | loss: 0.66146 | val_auc: 0.64007 | 0:22:13s
epoch 40 | loss: 0.66237 | val_auc: 0.64468 | 0:41:44s
epoch 41 | loss: 0.66187 | val_auc: 0.64453 | 0:41:56s
epoch 42 | loss: 0.65994 | val_auc: 0.64117 | 0:42:04s
epoch 43 | loss: 0.65739 | val_auc: 0.63615 | 0:42:12s
epoch 44 | loss: 0.65858 | val_auc: 0.64934 | 0:42:21s
epoch 45 | loss: 0.66051 | val_auc: 0.64757 | 0:42:27s
epoch 46 | loss: 0.65543 | val_auc: 0.64908 | 0:42:34s
epoch 47 | loss: 0.65763 | val_auc: 0.6431 | 0:42:40s
epoch 48 | loss: 0.65457 | val_auc: 0.64378 | 0:42:46s
epoch 49 | loss: 0.6593 | val_auc: 0.64737 | 0:42:53s
epoch 50 | loss: 0.65383 | val_auc: 0.65483 | 0:42:59s
epoch 51 | loss: 0.65388 | val_auc: 0.65875 | 0:43:06s
epoch 52 | loss: 0.65233 | val_auc: 0.65741 | 0:43:13s
epoch 53 | loss: 0.64907 | val_auc: 0.66099 | 0:43:19s
epoch 54 | loss: 0.64801 | val_auc: 0.66493 | 0:43:26s
epoch 55 | loss: 0.64823 | val_auc: 0.65804 | 0:43:34s
epoch 56 | loss: 0.64292 | val_auc: 0.66288 | 0:43:41s
epoch 57 | loss: 0.64196 | val_auc: 0.67365 | 0:43:48s
epoch 58 | loss: 0.6332 | val_auc: 0.66825 | 0:43:55s
epoch 59 | loss: 0.62525 | val_auc: 0.68516 | 0:44:02s
epoch 60 | loss: 0.61705 | val_auc: 0.69033 | 0:44:09s
epoch 61 | loss: 0.60889 | val_auc: 0.71261 | 0:44:17s
epoch 62 | loss: 0.59198 | val_auc: 0.72509 | 0:44:24s
epoch 63 | loss: 0.58121 | val_auc: 0.73198 | 0:44:31s
epoch 64 | loss: 0.57331 | val_auc: 0.7366 | 0:44:38s
epoch 65 | loss: 0.56609 | val_auc: 0.68992 | 0:44:45s
epoch 66 | loss: 0.55844 | val_auc: 0.71954 | 0:44:53s
epoch 67 | loss: 0.55355 | val_auc: 0.72913 | 0:45:00s
epoch 68 | loss: 0.54844 | val_auc: 0.72841 | 0:45:07s
epoch 69 | loss: 0.54431 | val_auc: 0.733 | 0:45:14s
epoch 70 | loss: 0.53927 | val_auc: 0.71739 | 0:45:22s
epoch 71 | loss: 0.5337 | val_auc: 0.73548 | 0:45:29s
epoch 72 | loss: 0.53193 | val_auc: 0.70763 | 0:45:36s
epoch 73 | loss: 0.53044 | val_auc: 0.71388 | 0:45:43s
epoch 74 | loss: 0.52454 | val_auc: 0.73895 | 0:45:51s
epoch 75 | loss: 0.5225 | val_auc: 0.73833 | 0:45:58s
epoch 76 | loss: 0.51702 | val_auc: 0.74352 | 0:46:05s
epoch 77 | loss: 0.52091 | val_auc: 0.72115 | 0:46:13s
epoch 78 | loss: 0.51458 | val_auc: 0.73361 | 0:46:20s
epoch 79 | loss: 0.51612 | val_auc: 0.71282 | 0:46:28s
epoch 80 | loss: 0.51007 | val_auc: 0.74158 | 0:46:35s
epoch 81 | loss: 0.51016 | val_auc: 0.73994 | 0:46:43s
epoch 82 | loss: 0.50892 | val_auc: 0.65165 | 0:46:50s
epoch 83 | loss: 0.51212 | val_auc: 0.73442 | 0:46:57s
epoch 84 | loss: 0.50758 | val_auc: 0.734 | 0:47:05s
epoch 85 | loss: 0.50531 | val_auc: 0.74106 | 0:47:12s
epoch 86 | loss: 0.50151 | val_auc: 0.71251 | 0:47:20s
epoch 87 | loss: 0.50668 | val_auc: 0.72136 | 0:47:27s
epoch 88 | loss: 0.50778 | val_auc: 0.73639 | 0:47:35s
epoch 89 | loss: 0.50359 | val_auc: 0.72468 | 0:47:44s
epoch 90 | loss: 0.50021 | val_auc: 0.72629 | 0:47:52s
epoch 91 | loss: 0.50154 | val_auc: 0.72505 | 0:47:59s
epoch 92 | loss: 0.5002 | val_auc: 0.74648 | 0:48:07s
epoch 93 | loss: 0.49619 | val_auc: 0.72082 | 0:48:14s
epoch 94 | loss: 0.49808 | val_auc: 0.70147 | 0:48:22s
epoch 95 | loss: 0.49894 | val_auc: 0.7272 | 0:48:29s
epoch 96 | loss: 0.49718 | val_auc: 0.73616 | 0:48:37s
epoch 97 | loss: 0.49433 | val_auc: 0.75113 | 0:48:44s
epoch 98 | loss: 0.49574 | val_auc: 0.72391 | 0:48:52s
epoch 99 | loss: 0.49716 | val_auc: 0.69928 | 0:48:59s
epoch 100| loss: 0.49181 | val_auc: 0.73015 | 0:49:07s
epoch 101| loss: 0.49825 | val_auc: 0.70603 | 0:49:14s
epoch 102| loss: 0.49114 | val_auc: 0.75679 | 0:49:22s
epoch 103| loss: 0.49518 | val_auc: 0.7136 | 0:49:30s
epoch 104| loss: 0.49372 | val_auc: 0.72138 | 0:49:37s
epoch 105| loss: 0.49357 | val_auc: 0.7604 | 0:49:45s
epoch 106| loss: 0.49471 | val_auc: 0.75576 | 0:49:53s
epoch 107| loss: 0.4902 | val_auc: 0.68895 | 0:50:00s
epoch 108| loss: 0.49344 | val_auc: 0.49985 | 0:50:08s
epoch 109| loss: 0.4924 | val_auc: 0.75216 | 0:50:16s
epoch 110| loss: 0.48773 | val_auc: 0.75372 | 0:50:23s
epoch 111| loss: 0.48885 | val_auc: 0.75479 | 0:50:31s
epoch 112| loss: 0.4902 | val_auc: 0.55466 | 0:50:39s
epoch 113| loss: 0.48912 | val_auc: 0.73351 | 0:50:46s
epoch 114| loss: 0.48739 | val_auc: 0.72823 | 0:50:54s
epoch 115| loss: 0.48659 | val_auc: 0.75073 | 0:51:02s
epoch 116| loss: 0.48825 | val_auc: 0.7511 | 0:51:10s
epoch 117| loss: 0.4894 | val_auc: 0.68768 | 0:51:18s
epoch 118| loss: 0.48623 | val_auc: 0.69732 | 0:51:26s
epoch 119| loss: 0.4896 | val_auc: 0.74797 | 0:51:34s
epoch 120| loss: 0.48947 | val_auc: 0.75605 | 0:51:43s
epoch 121| loss: 0.48622 | val_auc: 0.74973 | 0:51:51s
epoch 122| loss: 0.48344 | val_auc: 0.71532 | 0:51:58s
epoch 123| loss: 0.48918 | val_auc: 0.73061 | 0:52:06s
epoch 124| loss: 0.48184 | val_auc: 0.64659 | 0:52:14s
epoch 125| loss: 0.48205 | val_auc: 0.74083 | 0:52:22s
Early stopping occurred at epoch 125 with best_epoch = 105 and best_val_auc = 0.7604
precision recall f1-score support
0 0.7906 1.0000 0.8831 6790
1 1.0000 0.4702 0.6397 3394
accuracy 0.8234 10184
macro avg 0.8953 0.7351 0.7614 10184
weighted avg 0.8604 0.8234 0.8020 10184
Fold 4 AUC: 0.7604
MCC: 0.6097
==== Fold 5 ====
epoch 0 | loss: 0.71305 | val_auc: 0.57825 | 0:00:07s
epoch 1 | loss: 0.68662 | val_auc: 0.58927 | 0:00:15s
epoch 2 | loss: 0.68255 | val_auc: 0.60638 | 0:00:23s
epoch 3 | loss: 0.67729 | val_auc: 0.6091 | 0:00:31s
epoch 4 | loss: 0.67632 | val_auc: 0.61426 | 0:00:39s
epoch 5 | loss: 0.67503 | val_auc: 0.6169 | 0:00:46s
epoch 6 | loss: 0.67486 | val_auc: 0.61544 | 0:00:54s
epoch 7 | loss: 0.6751 | val_auc: 0.61394 | 0:01:02s
epoch 8 | loss: 0.67459 | val_auc: 0.61279 | 0:01:10s
epoch 9 | loss: 0.67463 | val_auc: 0.61803 | 0:01:18s
epoch 10 | loss: 0.67213 | val_auc: 0.62082 | 0:01:25s
epoch 11 | loss: 0.67261 | val_auc: 0.61999 | 0:01:33s
epoch 12 | loss: 0.67108 | val_auc: 0.61774 | 0:01:41s
epoch 13 | loss: 0.66813 | val_auc: 0.62669 | 0:01:48s
epoch 14 | loss: 0.66755 | val_auc: 0.62806 | 0:01:56s
epoch 15 | loss: 0.66766 | val_auc: 0.63502 | 0:02:04s
epoch 16 | loss: 0.66792 | val_auc: 0.6349 | 0:02:12s
epoch 17 | loss: 0.66508 | val_auc: 0.63258 | 0:02:20s
epoch 18 | loss: 0.66732 | val_auc: 0.63808 | 0:02:28s
epoch 19 | loss: 0.66639 | val_auc: 0.62917 | 0:02:36s
epoch 20 | loss: 0.66791 | val_auc: 0.63797 | 0:02:43s
epoch 21 | loss: 0.66584 | val_auc: 0.63302 | 0:02:51s
epoch 22 | loss: 0.66386 | val_auc: 0.63666 | 0:02:59s
epoch 23 | loss: 0.66522 | val_auc: 0.64241 | 0:03:07s
epoch 24 | loss: 0.66441 | val_auc: 0.63817 | 0:03:16s
epoch 25 | loss: 0.66422 | val_auc: 0.63368 | 0:03:25s
epoch 26 | loss: 0.66616 | val_auc: 0.63348 | 0:03:34s
epoch 27 | loss: 0.66565 | val_auc: 0.63597 | 0:03:42s
epoch 28 | loss: 0.66532 | val_auc: 0.63909 | 0:03:50s
epoch 29 | loss: 0.66665 | val_auc: 0.63869 | 0:03:57s
epoch 30 | loss: 0.66418 | val_auc: 0.63934 | 0:04:06s
epoch 31 | loss: 0.66443 | val_auc: 0.64022 | 0:04:14s
epoch 32 | loss: 0.6626 | val_auc: 0.64237 | 0:04:22s
epoch 33 | loss: 0.66573 | val_auc: 0.63978 | 0:04:31s
epoch 34 | loss: 0.66166 | val_auc: 0.64345 | 0:04:39s
epoch 35 | loss: 0.66483 | val_auc: 0.64611 | 0:04:48s
epoch 36 | loss: 0.66584 | val_auc: 0.64238 | 0:04:56s
epoch 37 | loss: 0.66269 | val_auc: 0.64274 | 0:05:05s
epoch 38 | loss: 0.66112 | val_auc: 0.64267 | 0:05:14s
epoch 39 | loss: 0.66537 | val_auc: 0.6401 | 0:05:23s
epoch 40 | loss: 0.66322 | val_auc: 0.64375 | 0:05:32s
epoch 41 | loss: 0.66461 | val_auc: 0.64253 | 0:05:41s
epoch 42 | loss: 0.66167 | val_auc: 0.64271 | 0:05:50s
epoch 43 | loss: 0.66472 | val_auc: 0.64603 | 0:05:59s
epoch 44 | loss: 0.66321 | val_auc: 0.64158 | 0:06:08s
epoch 45 | loss: 0.6622 | val_auc: 0.64519 | 0:06:17s
epoch 46 | loss: 0.66451 | val_auc: 0.64448 | 0:06:27s
epoch 47 | loss: 0.66212 | val_auc: 0.64417 | 0:06:36s
epoch 48 | loss: 0.66134 | val_auc: 0.64511 | 0:06:45s
epoch 49 | loss: 0.6631 | val_auc: 0.6443 | 0:06:54s
epoch 50 | loss: 0.66151 | val_auc: 0.64272 | 0:07:03s
epoch 51 | loss: 0.6633 | val_auc: 0.64404 | 0:07:13s
epoch 52 | loss: 0.66263 | val_auc: 0.64297 | 0:07:22s
epoch 53 | loss: 0.66179 | val_auc: 0.6436 | 0:07:31s
epoch 54 | loss: 0.66498 | val_auc: 0.64279 | 0:07:41s
epoch 55 | loss: 0.66353 | val_auc: 0.64476 | 0:07:50s
Early stopping occurred at epoch 55 with best_epoch = 35 and best_val_auc = 0.64611
precision recall f1-score support
0 0.7645 0.5733 0.6552 6789
1 0.4312 0.6468 0.5174 3395
accuracy 0.5978 10184
macro avg 0.5978 0.6101 0.5863 10184
weighted avg 0.6534 0.5978 0.6093 10184
Fold 5 AUC: 0.6461
MCC: 0.2075
==== Cross-Validation Complete ==== Mean AUC: 0.6917 | Std AUC: 0.0614 Successfully saved model at Failure_Event_tabnet_model.zip
'Failure_Event_tabnet_model.zip'
TAB NET¶
TabNet is a neural network (NN) architecture specifically designed for tabular data.
- Traditional deep learning models (like MLPs) struggle with tabular data. TabNet, however, was specifically designed for structured datasets — which have usually been better handled by models like XGBoost or Random Forests.
- TabNet processes data in steps, and at each step, it uses an attention mechanism to decide
- Which features to focus on
- How much each feature should influence the prediction
- This is very different from tree-based models that use greedy splits or MLPs that treat all input features equally at all times.
- TabNet promotes sparsity in feature usage
- Each decision step uses only a small subset of features.
- This makes the model interpretable and efficient
- This is controlled by the gamma parameter and the entmax activation — which encourages attention to only a few important inputs.
- Handles Categorical Variables Natively
- Uses embeddings for categorical variables (like NLP models)
- No need for manual one-hot encoding or label encoding.
- Learns better representations for categories during training.
- TabNet is trained end-to-end using gradient descent, which makes it
- More flexible and scalable
- Capable of benefiting from powerful optimization tools like learning rate schedulers, early stopping, etc.
- Comparison of models
| Feature | TabNet | XGBoost / Random Forest | MLP (Feedforward Neural Net) |
|---|---|---|---|
| Feature Selection | Attention-based, sparse | Tree splits | Implicit (weights) |
| Interpretability | High (built-in feature masks) | Medium (requires SHAP/LIME) | Low |
| Categorical Handling | Native (embeddings) | Manual encoding (one-hot/label) | Manual encoding (one-hot/label) |
| Sequential Decision Steps | Yes | No | No |
| Handles Imbalanced Data | Yes (class weights + attention) | Yes (with tuning) | Needs balancing (e.g., SMOTE) |
| Training | End-to-end gradient descent | Gradient boosting / bagging | End-to-end gradient descent |
| Performance on Tabular | Very competitive | Strong baseline | Often underperforms |
TAB NET specific parameters¶
clf_F = TabNetClassifier(
n_d=32, n_a=32, n_steps=5, gamma=1.5,
cat_idxs=cat_idxs,
cat_dims=cat_dims,
cat_emb_dim=cat_emb_dim,
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=1e-2),
scheduler_params={"step_size":10, "gamma":0.9},
scheduler_fn=torch.optim.lr_scheduler.StepLR,
mask_type='entmax',
seed=RANDOM_STATE,
verbose=1
)
- `n_d=32`: Number of dimensions for the decision step output. Controls the size of the vector that holds learned representations at each step.
- `n_a=32`: Number of dimensions for the attention step output. Controls how much attention each feature gets during selection. Usually set equal to `n_d`.
- `n_steps=5`: Number of sequential decision steps. Each step decides which features to focus on and learns new representations. More steps allow more complex reasoning but increase computation.
- `gamma=1.5`: Controls the sparsity of feature selection. Higher `gamma` results in more sparse attention (focus on fewer features). This enhances interpretability and helps regularize the model.
- `cat_idxs=cat_idxs`: List of column indices for categorical features in the dataset. Tells TabNet which features require embedding.
- `cat_dims=cat_dims`: List containing the number of unique values for each categorical column. Used to define embedding layer dimensions.
- `cat_emb_dim=cat_emb_dim`: List of embedding dimensions for each categorical feature. Typically set using a rule like `min(50, (cat_dim + 1) // 2)`.
- `optimizer_fn=torch.optim.Adam`: Specifies the optimizer to use. Adam is used here for its efficiency and adaptive learning rate.
- `optimizer_params=dict(lr=1e-2)`: Sets the initial learning rate to 0.01.
- `scheduler_params={"step_size":10, "gamma":0.9}`: Learning rate scheduler parameters. Every 10 epochs, the learning rate is multiplied by 0.9 to gradually reduce it.
- `scheduler_fn=torch.optim.lr_scheduler.StepLR`: The function used to reduce the learning rate at fixed intervals. StepLR is a simple and commonly used scheduler.
- `mask_type='entmax'`: Specifies the attention mask type for feature selection. `entmax` creates sparse masks, allowing the model to focus only on the most relevant features and ignore the rest (assigned 0 attention), improving interpretability.
# Train the failure-event TabNet model, monitoring AUC on the held-out
# validation fold with early stopping.
fit_kwargs = dict(
    X_train=TAB_X_train.values,
    y_train=TAB_y_train.values,
    eval_set=[(TAB_X_val.values, TAB_y_val.values)],
    eval_name=['val'],
    eval_metric=['auc'],
    max_epochs=EPOCHS,
    patience=20,              # stop if val AUC stalls for 20 epochs
    batch_size=BATCH_SIZE,
    virtual_batch_size=128,   # ghost batch-norm chunk size
    weights=weights,          # per-class weights for the imbalanced target
)
clf_F.fit(**fit_kwargs)
TabNet fit() Parameter Explanation¶
- `X_train=TAB_X_train.values`: The training feature matrix as a NumPy array.
- `y_train=TAB_y_train.values`: The target labels for training as a NumPy array.
- `eval_set=[(TAB_X_val.values, TAB_y_val.values)]`: A list of tuples containing the validation set (features and labels). Used to monitor model performance during training.
- `eval_name=['val']`: The name associated with the evaluation set. Appears in the training logs to identify the validation metrics.
- `eval_metric=['auc']`: The evaluation metric to monitor. In this case, Area Under the ROC Curve (AUC) is used to track model performance.
- `max_epochs=EPOCHS`: The maximum number of epochs to train the model. Here, it's set using the `EPOCHS` constant (e.g., 200).
- `patience=20`: Early stopping parameter. Training will stop if the validation metric does not improve for 20 consecutive epochs.
- `batch_size=BATCH_SIZE`: Number of samples processed in each training batch. Controlled via the `BATCH_SIZE` variable (e.g., 64).
- `virtual_batch_size=128`: Used for Ghost Batch Normalization. Enables batch normalization over smaller subsets of the batch to simulate smaller-batch behavior, improving generalization and stability.
- `weights=weights`: Class weights used to handle class imbalance. Ensures the model pays appropriate attention to minority classes by penalizing their misclassification more heavily.
Best accuracy produced in Fold 4¶
- Accuracy: 0.8234
- AUC: 0.7604
- MCC: 0.6097
For the evaluation of model metrics and plots, refer to the report.
Evaluation on Unseen Data¶
# Score the failure-event TabNet model on the untouched validation split:
# AUC, full classification report, MCC, and a confusion-matrix plot.
TAB_x = X_val.values
TAB_y = y_val

y_pred_proba = clf_F.predict_proba(TAB_x)[:, 1]
y_pred = clf_F.predict(TAB_x)

# Ranking quality of the predicted probabilities
auc_score = roc_auc_score(TAB_y, y_pred_proba)
print(f"\nFinal Validation AUC: {auc_score:.4f}")

# Per-class precision / recall / F1
print("Classification Report:")
print(classification_report(TAB_y, y_pred, digits=4))

# MCC is robust to the 85/15 class imbalance in this split
mcc = matthews_corrcoef(TAB_y, y_pred)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

cm = confusion_matrix(TAB_y, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix")
plt.show()
# ROC curve for the failure-event model on the validation split.
from sklearn.metrics import roc_curve, auc

fpr, tpr, thresholds = roc_curve(TAB_y, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2,
         label=f'ROC Curve (AUC = {roc_auc:.4f})')
# Diagonal = random-guess baseline
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid()
plt.tight_layout()
plt.show()
from sklearn.metrics import precision_recall_curve, average_precision_score

# Precision-recall view plus a per-class histogram of predicted probabilities.
# Fix: precision/recall/avg_precision were previously computed but never
# used (dead code) — the PR curve they describe is now actually plotted.
precision, recall, _ = precision_recall_curve(TAB_y, y_pred_proba)
avg_precision = average_precision_score(TAB_y, y_pred_proba)

plt.figure(figsize=(6, 4))
plt.plot(recall, precision, color='purple', lw=2,
         label=f'PR Curve (AP = {avg_precision:.4f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc='lower left')
plt.grid()
plt.tight_layout()
plt.show()

# Overlapping histograms show how well the two classes are separated by
# the predicted probability.
plt.figure(figsize=(6, 4))
plt.hist(y_pred_proba[TAB_y == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
plt.hist(y_pred_proba[TAB_y == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
plt.xlabel('Predicted Probability')
plt.ylabel('Count')
plt.title('Histogram of Predicted Probabilities by Class')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()
import seaborn as sns

# Global feature importance from TabNet's attention masks: summing the
# per-sample explanation matrix over rows gives one aggregate importance
# score per feature.
explain_matrix, masks = clf_F.explain(TAB_x)  # TAB_x: validation features

feature_importances_df = pd.DataFrame({
    'Feature': X_val.columns,
    'Importance': explain_matrix.sum(axis=0)
}).sort_values(by='Importance', ascending=False)

plt.figure(figsize=(8, 6))
# Fix: pass hue + legend=False — `palette` without `hue` is deprecated in
# seaborn >= 0.13 and raises a FutureWarning.
sns.barplot(data=feature_importances_df, x='Importance', y='Feature',
            hue='Feature', palette='viridis', legend=False)
plt.title("TabNet Feature Importances (Global Explanation)")
plt.tight_layout()
plt.show()
Final Validation AUC: 0.5075
Classification Report:
precision recall f1-score support
0 0.8522 0.5678 0.6815 8487
1 0.1558 0.4475 0.2311 1513
accuracy 0.5496 10000
macro avg 0.5040 0.5076 0.4563 10000
weighted avg 0.7468 0.5496 0.6134 10000
Matthews Correlation Coefficient (MCC): 0.0110
Predicted probability histogram¶
- Class 0 (No Failure) predictions are mostly centered around 0.4–0.6.
- Class 1 (Failure) predictions overlap with Class 0 and also tend to peak around 0.4–0.5.
- This overlap indicates that the model is:
- Not highly confident in distinguishing failures from non-failures.
- Producing probabilities clustered near the middle range instead of near 0 or 1.
Feature importance¶
- Vibration_Levels has the highest importance. This suggests it's the most predictive feature for the Failure_Event target.
- Fuel_Consumption, Humidity, and Age_of_Asset also have high importance, indicating strong influence on failure predictions.
- Moderately useful features:
- Pressure has some influence but significantly less than the top four.
- Least used features:
- Features like Thermal_Stress, Temperature, Fuel_Efficiency, Asset_Type, Maintenance_History, and others at the bottom have very low to negligible importance in this model. It doesn't mean they are useless — just that TabNet didn't find them as helpful in the current data context.
Needs Maintenance¶
Re-preparing the Data¶
# Binary maintenance target: any recorded maintenance history (> 0)
# counts as "needs maintenance".
df["needs_maintenance"] = df["Maintenance_History"].gt(0).astype(int)
df.value_counts('needs_maintenance')
needs_maintenance 1 33303 0 16697 Name: count, dtype: int64
df.value_counts('Maintenance_History')
Maintenance_History 2 16714 0 16697 1 16589 Name: count, dtype: int64
df[["Asset_Type","needs_maintenance"]].value_counts()
Asset_Type needs_maintenance 2 1 11167 1 1 11069 0 1 11067 1 0 5640 0 0 5542 2 0 5515 Name: count, dtype: int64
Data Pre processing¶
# Features and target. Maintenance_History is dropped along with the
# derived target because the label was computed directly from it
# (keeping it would leak the label into the features).
X = df.drop(columns=['Maintenance_History','needs_maintenance'])
y = df['needs_maintenance']
# Define categorical and numerical features.
# NOTE(review): Asset_Type/Location appear to already be integer-encoded
# (SMOTENC and TabNet below consume them as-is) — confirm upstream encoding.
cat_features = ['Asset_Type', 'Location']
num_features = [col for col in X.columns if col not in cat_features]
# Train-test split, stratified so the ~2:1 class ratio is preserved in both splits
X_train, X_val, y_train, y_val = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# --- Apply SMOTE to balance the classes in the training data only ---
# (the validation split is left untouched so metrics reflect the true distribution)
categorical_indices = [X_train.columns.get_loc(col) for col in cat_features]
print("Before SMOTE:", Counter(y_train))
# SMOTENC handles mixed categorical + continuous features (plain SMOTE would
# interpolate invalid category values).
smote_nc = SMOTENC(
categorical_features=categorical_indices,
sampling_strategy=0.7, # Resample minority class up to 70% of the majority class size
random_state=42
)
X_train_balanced, y_train_balanced = smote_nc.fit_resample(X_train, y_train)
# Convert back to DataFrame/Series for further processing
X_train_final = pd.DataFrame(X_train_balanced, columns=X_train.columns)
y_train_final = pd.Series(y_train_balanced, name='needs_maintenance')
print("After SMOTE:", Counter(y_train_balanced))
# --- Scale numeric features AFTER SMOTE (synthetic samples get scaled too) ---
scaler = StandardScaler()
# Separate numerical columns for scaling
X_train_num = X_train_final[num_features]
X_val_num = X_val[num_features]
# Fit the scaler on training data only; transform validation with the same
# parameters to avoid leakage.
X_train_num_scaled = scaler.fit_transform(X_train_num)
X_val_num_scaled = scaler.transform(X_val_num)
# Convert back to DataFrame
X_train_num_scaled_df = pd.DataFrame(X_train_num_scaled, columns=num_features).reset_index(drop=True)
X_val_num_scaled_df = pd.DataFrame(X_val_num_scaled, columns=num_features).reset_index(drop=True)
# Final sets: scaled numeric columns + raw categorical columns. The
# reset_index(drop=True) calls are load-bearing — they align rows
# positionally, since fit_resample returns a fresh RangeIndex while X_val
# keeps its original index.
X_train_final = pd.concat([X_train_num_scaled_df, X_train_final[cat_features].reset_index(drop=True)], axis=1)
X_val_final = pd.concat([X_val_num_scaled_df, X_val[cat_features].reset_index(drop=True)], axis=1)
Before SMOTE: Counter({1: 26642, 0: 13358})
After SMOTE: Counter({1: 26642, 0: 18649})
Random Forest Model¶
Grid search for best parameters¶
# Hyperparameter search for the maintenance-needs Random Forest.
# The depth / split / leaf limits act as pre-pruning controls.
forest = RandomForestClassifier(random_state=42, n_jobs=-1,
                                class_weight='balanced')

search_space = {
    'n_estimators': [100],
    'max_depth': [3, 5, 7, None],      # cap tree size (pruning)
    'min_samples_split': [2, 5, 10],   # prevent overgrowth
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2'],
}

grid_search = GridSearchCV(estimator=forest, param_grid=search_space,
                           cv=3, scoring='roc_auc', verbose=1)
# Trees are scale-invariant, so the unscaled (SMOTE-balanced) data is used.
grid_search.fit(X_train_balanced, y_train_balanced)

# Evaluate the winning configuration on the untouched validation split.
best_rf_M = grid_search.best_estimator_
y_pred = best_rf_M.predict(X_val)
y_prob = best_rf_M.predict_proba(X_val)[:, 1]
print("Best Params:", grid_search.best_params_)
print("AUC-ROC:", roc_auc_score(y_val, y_prob))
print(classification_report(y_val, y_pred, digits=4))
mcc = matthews_corrcoef(y_val, y_pred)
print(f"MCC: {mcc:.4f}")

# Visualise the first tree of the forest (top 3 levels only, for clarity).
plt.figure(figsize=(30, 20))
plot_tree(best_rf_M.estimators_[0],
          feature_names=X.columns,
          class_names=['Class 0', 'Class 1'],
          filled=True,
          rounded=True,
          max_depth=3,
          fontsize=10)
plt.title("Random Forest - Tree Visualization")
plt.show()

# Persist the tuned model for reuse.
joblib.dump(best_rf_M, 'maintainence_needs_random_forest_model.joblib')
Fitting 3 folds for each of 72 candidates, totalling 216 fits
Best Params: {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}
AUC-ROC: 0.5045436419698883
precision recall f1-score support
0 0.3440 0.1518 0.2107 3339
1 0.6678 0.8548 0.7499 6661
accuracy 0.6201 10000
macro avg 0.5059 0.5033 0.4803 10000
weighted avg 0.5597 0.6201 0.5698 10000
MCC: 0.0089
['maintainence_needs_random_forest_model.joblib']
K-fold on best model¶
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import (
    roc_auc_score,
    classification_report,
    matthews_corrcoef,
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_recall_curve,
    auc  # re-imported so `auc` is the sklearn function here, not a stale value
)

# 5-fold stratified cross-validation of the tuned Random Forest on the
# SMOTE-balanced training data, with per-fold report, confusion matrix,
# probability histogram and ROC curve.
# NOTE(review): each fold refits `best_rf_M` in place, so after this loop
# the object holds the *last* fold's fit (the joblib dump above saved the
# grid-search fit) — confirm this is intended.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Collect per-fold AUCs for the summary at the end
roc_auc_scores = []

for fold, (train_idx, val_idx) in enumerate(cv.split(X_train_balanced, y_train_balanced), start=1):
    X_train_fold, X_val_fold = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    y_train_fold, y_val_fold = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]

    # Train the best model on this fold
    best_rf_M.fit(X_train_fold, y_train_fold)
    y_pred = best_rf_M.predict(X_val_fold)
    y_proba = best_rf_M.predict_proba(X_val_fold)[:, 1]

    # AUC for this fold
    roc_auc = roc_auc_score(y_val_fold, y_proba)
    roc_auc_scores.append(roc_auc)

    # Print classification report
    print(f"\nFold {fold} - Classification Report:")
    print(classification_report(y_val_fold, y_pred, digits=4))
    print(f"Fold {fold} - AUC-ROC: {roc_auc:.4f}")
    mcc = matthews_corrcoef(y_val_fold, y_pred)
    print(f"MCC: {mcc:.4f}")

    # Confusion Matrix
    cm = confusion_matrix(y_val_fold, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve points. (Fix: the dead store `roc_auc_curve = auc(fpr, tpr)`
    # was removed — it duplicated `roc_auc` and was never used.)
    fpr, tpr, thresholds = roc_curve(y_val_fold, y_proba)

    # Histogram of predicted probabilities by class
    plt.figure(figsize=(6, 4))
    plt.hist(y_proba[y_val_fold == 0], bins=30, alpha=0.6, label='Class 0 (Does not require)', color='skyblue')
    plt.hist(y_proba[y_val_fold == 1], bins=30, alpha=0.6, label='Class 1 (Needs_Maintainence)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Plot ROC Curve
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC AUC = {roc_auc:.4f}")
    plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')  # Diagonal line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

# Print average AUC over all folds
print("\nAverage AUC-ROC across folds:")
print(f"{np.mean(roc_auc_scores):.4f} ± {np.std(roc_auc_scores):.4f}")
Fold 1 - Classification Report:
precision recall f1-score support
0 0.5921 0.3051 0.4027 3730
1 0.6368 0.8529 0.7292 5329
accuracy 0.6273 9059
macro avg 0.6145 0.5790 0.5659 9059
weighted avg 0.6184 0.6273 0.5948 9059
Fold 1 - AUC-ROC: 0.6146
MCC: 0.1902
Fold 2 - Classification Report:
precision recall f1-score support
0 0.5830 0.3138 0.4079 3729
1 0.6371 0.8429 0.7257 5329
accuracy 0.6251 9058
macro avg 0.6100 0.5783 0.5668 9058
weighted avg 0.6148 0.6251 0.5949 9058
Fold 2 - AUC-ROC: 0.6106
MCC: 0.1857
Fold 3 - Classification Report:
precision recall f1-score support
0 0.5851 0.2968 0.3938 3730
1 0.6340 0.8527 0.7272 5328
accuracy 0.6238 9058
macro avg 0.6095 0.5747 0.5605 9058
weighted avg 0.6138 0.6238 0.5899 9058
Fold 3 - AUC-ROC: 0.6084
MCC: 0.1809
Fold 4 - Classification Report:
precision recall f1-score support
0 0.5976 0.3118 0.4098 3730
1 0.6391 0.8530 0.7307 5328
accuracy 0.6302 9058
macro avg 0.6183 0.5824 0.5703 9058
weighted avg 0.6220 0.6302 0.5986 9058
Fold 4 - AUC-ROC: 0.6091
MCC: 0.1975
Fold 5 - Classification Report:
precision recall f1-score support
0 0.5812 0.2965 0.3927 3730
1 0.6333 0.8504 0.7259 5328
accuracy 0.6223 9058
macro avg 0.6072 0.5735 0.5593 9058
weighted avg 0.6118 0.6223 0.5887 9058
Fold 5 - AUC-ROC: 0.6113
MCC: 0.1775
Average AUC-ROC across folds: 0.6108 ± 0.0022
TAB NET¶
# Configuration for the maintenance-needs TabNet cross-validation.
# NOTE(review): TARGET_COL says 'Failure_Event' although this section models
# needs_maintenance — it is unused in the code below, but confirm it is not
# read by a later cell before relying on it.
TARGET_COL = 'Failure_Event'
CATEGORICAL_COLS = ['Asset_Type', 'Location']
RANDOM_STATE = 42
N_SPLITS = 5
EPOCHS = 100
BATCH_SIZE = 64
# Identify categorical feature indices (positions in the balanced training
# frame) and cardinalities (from the full df) for TabNet's embeddings.
cat_idxs = [X_train_balanced.columns.get_loc(col) for col in CATEGORICAL_COLS]
cat_dims = [int(df[col].nunique()) for col in CATEGORICAL_COLS]
# Embedding size rule of thumb: half the cardinality, capped at 50
cat_emb_dim = [min(50, (dim + 1) // 2) for dim in cat_dims]
# Cross-validation setup
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
fold = 1
auc_scores = []
# Compute class weights once on the full (SMOTE-balanced) training target —
# despite the balancing, the 0.7 sampling_strategy leaves a residual imbalance.
classes = np.unique(y_train_balanced)
from sklearn.utils.class_weight import compute_class_weight
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train_balanced)
weights = dict(zip(classes, class_weights))
# 5-fold stratified CV of TabNet on the maintenance-needs target.
# A fresh model is built per fold so folds do not share state.
for train_idx, val_idx in skf.split(X_train_balanced, y_train_balanced):
    print(f"\n==== Fold {fold} ====")
    TAB_X_train, TAB_X_val = X_train_balanced.values[train_idx], X_train_balanced.values[val_idx]
    TAB_y_train, TAB_y_val = y_train_balanced.values[train_idx], y_train_balanced.values[val_idx]

    # Initialize and train the model
    clf_M = TabNetClassifier(
        n_d=32, n_a=32, n_steps=5, gamma=1.5,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=cat_emb_dim,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=1e-2),
        scheduler_params={"step_size":10, "gamma":0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='entmax',
        seed=RANDOM_STATE,
        verbose=1
    )
    clf_M.fit(
        X_train=TAB_X_train, y_train=TAB_y_train,
        eval_set=[(TAB_X_val, TAB_y_val)],
        eval_name=['val'],
        eval_metric=['auc'],
        max_epochs=EPOCHS,
        patience=20,
        batch_size=BATCH_SIZE,
        virtual_batch_size=128,
        weights=weights
    )

    # Evaluation on the held-out fold.
    y_pred_proba = clf_M.predict_proba(TAB_X_val)[:, 1]
    y_pred = clf_M.predict(TAB_X_val)
    # Fix: the score was previously bound to the name `auc`, shadowing
    # sklearn's `auc` function and forcing a re-import inside the loop.
    fold_auc = roc_auc_score(TAB_y_val, y_pred_proba)
    auc_scores.append(fold_auc)
    print(f"Fold {fold} AUC: {fold_auc:.4f}")
    print(classification_report(TAB_y_val, y_pred, digits=4))
    mcc = matthews_corrcoef(TAB_y_val, y_pred)
    print(f"MCC: {mcc:.4f}")

    # Confusion matrix
    cm = confusion_matrix(TAB_y_val, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve (the `auc` function here is the head-of-file sklearn import)
    fpr, tpr, thresholds = roc_curve(TAB_y_val, y_pred_proba)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Histogram of predicted probabilities by class
    plt.figure(figsize=(6, 4))
    plt.hist(y_pred_proba[TAB_y_val == 0], bins=30, alpha=0.6, label='Class 0 (Does not require)', color='skyblue')
    plt.hist(y_pred_proba[TAB_y_val == 1], bins=30, alpha=0.6, label='Class 1 (Needs_Maintainence)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    fold += 1

# Final average AUC across folds
print(f"\n==== Cross-Validation Complete ====")
print(f"Mean AUC: {np.mean(auc_scores):.4f} | Std AUC: {np.std(auc_scores):.4f}")
# Persist the last fold's model
clf_M.save_model("Needs_Maintainence_tabnet_model")
==== Fold 1 ====
epoch 0 | loss: 0.73862 | val_auc: 0.51499 | 0:00:10s
epoch 1 | loss: 0.69416 | val_auc: 0.53331 | 0:00:20s
epoch 2 | loss: 0.69339 | val_auc: 0.53815 | 0:00:31s
epoch 3 | loss: 0.69292 | val_auc: 0.53089 | 0:00:41s
epoch 4 | loss: 0.69277 | val_auc: 0.54502 | 0:00:50s
epoch 5 | loss: 0.69199 | val_auc: 0.53341 | 0:01:00s
epoch 6 | loss: 0.69283 | val_auc: 0.53721 | 0:01:09s
epoch 7 | loss: 0.69194 | val_auc: 0.54032 | 0:01:21s
epoch 8 | loss: 0.69202 | val_auc: 0.5424 | 0:01:31s
epoch 9 | loss: 0.69166 | val_auc: 0.54612 | 0:01:51s
epoch 10 | loss: 0.69144 | val_auc: 0.53952 | 0:02:01s
epoch 11 | loss: 0.6909 | val_auc: 0.55133 | 0:02:12s
epoch 12 | loss: 0.69087 | val_auc: 0.5454 | 0:02:25s
epoch 13 | loss: 0.69113 | val_auc: 0.53313 | 0:02:36s
epoch 14 | loss: 0.69073 | val_auc: 0.54081 | 0:02:46s
epoch 15 | loss: 0.69203 | val_auc: 0.54166 | 0:02:56s
epoch 16 | loss: 0.69194 | val_auc: 0.536 | 0:03:06s
epoch 17 | loss: 0.69139 | val_auc: 0.54411 | 0:03:16s
epoch 18 | loss: 0.69065 | val_auc: 0.54168 | 0:03:26s
epoch 19 | loss: 0.69148 | val_auc: 0.54292 | 0:03:36s
epoch 20 | loss: 0.69137 | val_auc: 0.54744 | 0:03:46s
epoch 21 | loss: 0.69066 | val_auc: 0.54813 | 0:03:55s
epoch 22 | loss: 0.69123 | val_auc: 0.54103 | 0:04:05s
epoch 23 | loss: 0.6911 | val_auc: 0.54438 | 0:04:15s
epoch 24 | loss: 0.69162 | val_auc: 0.54191 | 0:04:24s
epoch 25 | loss: 0.69145 | val_auc: 0.54712 | 0:04:34s
epoch 26 | loss: 0.6899 | val_auc: 0.54965 | 0:04:43s
epoch 27 | loss: 0.69123 | val_auc: 0.54522 | 0:04:53s
epoch 28 | loss: 0.69039 | val_auc: 0.5529 | 0:05:02s
epoch 29 | loss: 0.69031 | val_auc: 0.54752 | 0:05:11s
epoch 30 | loss: 0.69056 | val_auc: 0.54231 | 0:05:20s
epoch 31 | loss: 0.69079 | val_auc: 0.54904 | 0:05:30s
epoch 32 | loss: 0.69075 | val_auc: 0.54459 | 0:05:39s
epoch 33 | loss: 0.69099 | val_auc: 0.53882 | 0:05:50s
epoch 34 | loss: 0.6911 | val_auc: 0.54586 | 0:06:00s
epoch 35 | loss: 0.69131 | val_auc: 0.52696 | 0:06:11s
epoch 36 | loss: 0.691 | val_auc: 0.5423 | 0:06:21s
epoch 37 | loss: 0.69104 | val_auc: 0.54299 | 0:06:30s
epoch 38 | loss: 0.69103 | val_auc: 0.54231 | 0:06:40s
epoch 39 | loss: 0.69053 | val_auc: 0.53998 | 0:06:50s
epoch 40 | loss: 0.69161 | val_auc: 0.5415 | 0:06:59s
epoch 41 | loss: 0.69078 | val_auc: 0.54198 | 0:07:09s
epoch 42 | loss: 0.69055 | val_auc: 0.54156 | 0:07:19s
epoch 43 | loss: 0.6903 | val_auc: 0.55016 | 0:07:28s
epoch 44 | loss: 0.69023 | val_auc: 0.55124 | 0:07:37s
epoch 45 | loss: 0.69046 | val_auc: 0.55141 | 0:07:47s
epoch 46 | loss: 0.69092 | val_auc: 0.54714 | 0:07:56s
epoch 47 | loss: 0.69081 | val_auc: 0.54517 | 0:08:05s
epoch 48 | loss: 0.69081 | val_auc: 0.54334 | 0:08:15s
Early stopping occurred at epoch 48 with best_epoch = 28 and best_val_auc = 0.5529
Fold 1 AUC: 0.5529
precision recall f1-score support
0 0.4578 0.4670 0.4624 3730
1 0.6216 0.6129 0.6172 5329
accuracy 0.5528 9059
macro avg 0.5397 0.5399 0.5398 9059
weighted avg 0.5542 0.5528 0.5535 9059
MCC: 0.0797
==== Fold 2 ====
epoch 0 | loss: 0.74367 | val_auc: 0.51752 | 0:00:09s
epoch 1 | loss: 0.69457 | val_auc: 0.52189 | 0:00:18s
epoch 2 | loss: 0.69358 | val_auc: 0.5221 | 0:00:27s
epoch 3 | loss: 0.69297 | val_auc: 0.52514 | 0:00:38s
epoch 4 | loss: 0.69195 | val_auc: 0.5362 | 0:00:48s
epoch 5 | loss: 0.69141 | val_auc: 0.53066 | 0:00:58s
epoch 6 | loss: 0.69216 | val_auc: 0.54248 | 0:01:08s
epoch 7 | loss: 0.6919 | val_auc: 0.54002 | 0:01:18s
epoch 8 | loss: 0.69249 | val_auc: 0.53261 | 0:01:27s
epoch 9 | loss: 0.69238 | val_auc: 0.52982 | 0:01:37s
epoch 10 | loss: 0.69158 | val_auc: 0.53861 | 0:01:46s
epoch 11 | loss: 0.6917 | val_auc: 0.534 | 0:01:56s
epoch 12 | loss: 0.69119 | val_auc: 0.53371 | 0:02:06s
epoch 13 | loss: 0.6917 | val_auc: 0.53483 | 0:02:15s
epoch 14 | loss: 0.69118 | val_auc: 0.53192 | 0:02:25s
epoch 15 | loss: 0.69228 | val_auc: 0.53583 | 0:02:34s
epoch 16 | loss: 0.69103 | val_auc: 0.52722 | 0:02:44s
epoch 17 | loss: 0.69155 | val_auc: 0.52847 | 0:02:53s
epoch 18 | loss: 0.69104 | val_auc: 0.53745 | 0:03:03s
epoch 19 | loss: 0.69131 | val_auc: 0.5522 | 0:03:12s
epoch 20 | loss: 0.69 | val_auc: 0.55333 | 0:03:22s
epoch 21 | loss: 0.6905 | val_auc: 0.53932 | 0:03:32s
epoch 22 | loss: 0.69032 | val_auc: 0.5469 | 0:03:42s
epoch 23 | loss: 0.69092 | val_auc: 0.54167 | 0:03:51s
epoch 24 | loss: 0.69007 | val_auc: 0.54281 | 0:04:01s
epoch 25 | loss: 0.69042 | val_auc: 0.55122 | 0:04:11s
epoch 26 | loss: 0.6901 | val_auc: 0.54927 | 0:04:20s
epoch 27 | loss: 0.69133 | val_auc: 0.54574 | 0:04:29s
epoch 28 | loss: 0.69108 | val_auc: 0.54386 | 0:04:39s
epoch 29 | loss: 0.69033 | val_auc: 0.54586 | 0:04:48s
epoch 30 | loss: 0.68972 | val_auc: 0.54617 | 0:04:58s
epoch 31 | loss: 0.69002 | val_auc: 0.54678 | 0:05:07s
epoch 32 | loss: 0.68884 | val_auc: 0.54365 | 0:05:17s
epoch 33 | loss: 0.69046 | val_auc: 0.5424 | 0:05:26s
epoch 34 | loss: 0.69054 | val_auc: 0.54204 | 0:05:35s
epoch 35 | loss: 0.68904 | val_auc: 0.53871 | 0:05:45s
epoch 36 | loss: 0.68944 | val_auc: 0.53307 | 0:05:54s
epoch 37 | loss: 0.69127 | val_auc: 0.53761 | 0:06:04s
epoch 38 | loss: 0.69016 | val_auc: 0.53836 | 0:06:14s
epoch 39 | loss: 0.68898 | val_auc: 0.543 | 0:06:24s
epoch 40 | loss: 0.69073 | val_auc: 0.52934 | 0:06:34s
Early stopping occurred at epoch 40 with best_epoch = 20 and best_val_auc = 0.55333
Fold 2 AUC: 0.5533
precision recall f1-score support
0 0.4479 0.5329 0.4867 3729
1 0.6231 0.5404 0.5788 5329
accuracy 0.5373 9058
macro avg 0.5355 0.5366 0.5328 9058
weighted avg 0.5510 0.5373 0.5409 9058
MCC: 0.0722
==== Fold 3 ====
epoch 0 | loss: 0.74304 | val_auc: 0.52747 | 0:00:11s
epoch 1 | loss: 0.69485 | val_auc: 0.51745 | 0:00:22s
epoch 2 | loss: 0.69377 | val_auc: 0.53443 | 0:00:34s
epoch 3 | loss: 0.69338 | val_auc: 0.52869 | 0:00:47s
epoch 4 | loss: 0.69234 | val_auc: 0.53622 | 0:00:59s
epoch 5 | loss: 0.69235 | val_auc: 0.5275 | 0:01:12s
epoch 6 | loss: 0.69239 | val_auc: 0.52821 | 0:01:25s
epoch 7 | loss: 0.69255 | val_auc: 0.53736 | 0:01:36s
epoch 8 | loss: 0.69244 | val_auc: 0.53165 | 0:01:46s
epoch 9 | loss: 0.6927 | val_auc: 0.53163 | 0:01:57s
epoch 10 | loss: 0.69191 | val_auc: 0.53879 | 0:02:07s
epoch 11 | loss: 0.69221 | val_auc: 0.54915 | 0:02:17s
epoch 12 | loss: 0.69166 | val_auc: 0.54173 | 0:02:27s
epoch 13 | loss: 0.69172 | val_auc: 0.54136 | 0:02:38s
epoch 14 | loss: 0.69246 | val_auc: 0.53891 | 0:02:48s
epoch 15 | loss: 0.69177 | val_auc: 0.54401 | 0:02:58s
epoch 16 | loss: 0.69112 | val_auc: 0.54718 | 0:03:08s
epoch 17 | loss: 0.69115 | val_auc: 0.54855 | 0:03:18s
epoch 18 | loss: 0.69055 | val_auc: 0.5427 | 0:03:29s
epoch 19 | loss: 0.69139 | val_auc: 0.54593 | 0:03:39s
epoch 20 | loss: 0.69151 | val_auc: 0.54426 | 0:03:49s
epoch 21 | loss: 0.69124 | val_auc: 0.5407 | 0:03:59s
epoch 22 | loss: 0.69182 | val_auc: 0.5373 | 0:04:09s
epoch 23 | loss: 0.69217 | val_auc: 0.53434 | 0:04:19s
epoch 24 | loss: 0.69191 | val_auc: 0.54526 | 0:04:28s
epoch 25 | loss: 0.69145 | val_auc: 0.53809 | 0:04:38s
epoch 26 | loss: 0.69112 | val_auc: 0.53639 | 0:04:47s
epoch 27 | loss: 0.69169 | val_auc: 0.53389 | 0:04:57s
epoch 28 | loss: 0.6912 | val_auc: 0.53249 | 0:05:07s
epoch 29 | loss: 0.69153 | val_auc: 0.53651 | 0:05:16s
epoch 30 | loss: 0.69081 | val_auc: 0.52509 | 0:05:25s
epoch 31 | loss: 0.69062 | val_auc: 0.52946 | 0:05:35s
Early stopping occurred at epoch 31 with best_epoch = 11 and best_val_auc = 0.54915
Fold 3 AUC: 0.5491
precision recall f1-score support
0 0.4348 0.6461 0.5198 3730
1 0.6245 0.4120 0.4964 5328
accuracy 0.5084 9058
macro avg 0.5296 0.5290 0.5081 9058
weighted avg 0.5464 0.5084 0.5061 9058
MCC: 0.0587
==== Fold 4 ====
epoch 0 | loss: 0.73705 | val_auc: 0.51952 | 0:00:10s
epoch 1 | loss: 0.69329 | val_auc: 0.53089 | 0:00:20s
epoch 2 | loss: 0.69252 | val_auc: 0.54646 | 0:00:31s
epoch 3 | loss: 0.69228 | val_auc: 0.54065 | 0:00:41s
epoch 4 | loss: 0.69184 | val_auc: 0.534 | 0:00:51s
epoch 5 | loss: 0.69193 | val_auc: 0.54657 | 0:01:01s
epoch 6 | loss: 0.69151 | val_auc: 0.55287 | 0:01:10s
epoch 7 | loss: 0.69107 | val_auc: 0.54494 | 0:01:20s
epoch 8 | loss: 0.69122 | val_auc: 0.52761 | 0:01:29s
epoch 9 | loss: 0.69107 | val_auc: 0.55175 | 0:01:38s
epoch 10 | loss: 0.68997 | val_auc: 0.54981 | 0:01:48s
epoch 11 | loss: 0.69011 | val_auc: 0.54628 | 0:01:58s
epoch 12 | loss: 0.69029 | val_auc: 0.54834 | 0:02:07s
epoch 13 | loss: 0.68975 | val_auc: 0.54966 | 0:02:15s
epoch 14 | loss: 0.68997 | val_auc: 0.54721 | 0:02:25s
epoch 15 | loss: 0.69032 | val_auc: 0.55026 | 0:02:36s
epoch 16 | loss: 0.69028 | val_auc: 0.55701 | 0:02:47s
epoch 17 | loss: 0.69031 | val_auc: 0.54663 | 0:02:57s
epoch 18 | loss: 0.69004 | val_auc: 0.55328 | 0:03:08s
epoch 19 | loss: 0.68989 | val_auc: 0.55197 | 0:03:19s
epoch 20 | loss: 0.6898 | val_auc: 0.54563 | 0:03:30s
epoch 21 | loss: 0.68916 | val_auc: 0.55396 | 0:03:40s
epoch 22 | loss: 0.69026 | val_auc: 0.55576 | 0:03:51s
epoch 23 | loss: 0.68898 | val_auc: 0.55987 | 0:04:02s
epoch 24 | loss: 0.68901 | val_auc: 0.5621 | 0:04:14s
epoch 25 | loss: 0.68924 | val_auc: 0.56154 | 0:04:25s
epoch 26 | loss: 0.68936 | val_auc: 0.56151 | 0:04:34s
epoch 27 | loss: 0.68945 | val_auc: 0.55443 | 0:04:44s
epoch 28 | loss: 0.68945 | val_auc: 0.55818 | 0:04:53s
epoch 29 | loss: 0.68815 | val_auc: 0.55186 | 0:05:02s
epoch 30 | loss: 0.68838 | val_auc: 0.56158 | 0:05:11s
epoch 31 | loss: 0.68856 | val_auc: 0.56195 | 0:05:21s
epoch 32 | loss: 0.68918 | val_auc: 0.56096 | 0:05:30s
epoch 33 | loss: 0.68899 | val_auc: 0.56063 | 0:05:40s
epoch 34 | loss: 0.68933 | val_auc: 0.55781 | 0:05:49s
epoch 35 | loss: 0.68948 | val_auc: 0.55896 | 0:06:00s
epoch 36 | loss: 0.68867 | val_auc: 0.56113 | 0:06:12s
epoch 37 | loss: 0.68971 | val_auc: 0.56398 | 0:06:21s
epoch 38 | loss: 0.68879 | val_auc: 0.55588 | 0:06:31s
epoch 39 | loss: 0.6882 | val_auc: 0.5618 | 0:06:40s
epoch 40 | loss: 0.68834 | val_auc: 0.55574 | 0:06:49s
epoch 41 | loss: 0.68724 | val_auc: 0.56463 | 0:06:58s
epoch 42 | loss: 0.68692 | val_auc: 0.56536 | 0:07:07s
epoch 43 | loss: 0.68593 | val_auc: 0.56375 | 0:07:17s
epoch 44 | loss: 0.68509 | val_auc: 0.5708 | 0:07:27s
epoch 45 | loss: 0.68516 | val_auc: 0.56854 | 0:07:37s
epoch 46 | loss: 0.68481 | val_auc: 0.56642 | 0:07:49s
epoch 47 | loss: 0.68374 | val_auc: 0.57037 | 0:08:02s
epoch 48 | loss: 0.6832 | val_auc: 0.5712 | 0:08:13s
epoch 49 | loss: 0.68235 | val_auc: 0.56695 | 0:08:24s
epoch 50 | loss: 0.68079 | val_auc: 0.5731 | 0:08:35s
epoch 51 | loss: 0.67951 | val_auc: 0.56181 | 0:08:45s
epoch 52 | loss: 0.67734 | val_auc: 0.56499 | 0:08:56s
epoch 53 | loss: 0.6731 | val_auc: 0.58119 | 0:09:06s
epoch 54 | loss: 0.66857 | val_auc: 0.58393 | 0:09:17s
epoch 55 | loss: 0.66559 | val_auc: 0.58995 | 0:09:28s
epoch 56 | loss: 0.66147 | val_auc: 0.57975 | 0:09:38s
epoch 57 | loss: 0.65804 | val_auc: 0.59748 | 0:09:48s
epoch 58 | loss: 0.65249 | val_auc: 0.59459 | 0:09:58s
epoch 59 | loss: 0.64925 | val_auc: 0.58605 | 0:10:10s
epoch 60 | loss: 0.64754 | val_auc: 0.59586 | 0:10:21s
epoch 61 | loss: 0.6474 | val_auc: 0.58105 | 0:10:33s
epoch 62 | loss: 0.6467 | val_auc: 0.60681 | 0:10:43s
epoch 63 | loss: 0.64527 | val_auc: 0.59729 | 0:10:54s
epoch 64 | loss: 0.64216 | val_auc: 0.58857 | 0:11:04s
epoch 65 | loss: 0.64467 | val_auc: 0.57536 | 0:11:15s
epoch 66 | loss: 0.64411 | val_auc: 0.59619 | 0:11:25s
epoch 67 | loss: 0.64353 | val_auc: 0.57535 | 0:11:35s
epoch 68 | loss: 0.64291 | val_auc: 0.60058 | 0:11:45s
epoch 69 | loss: 0.6421 | val_auc: 0.59725 | 0:11:55s
epoch 70 | loss: 0.64248 | val_auc: 0.59396 | 0:12:06s
epoch 71 | loss: 0.64041 | val_auc: 0.58804 | 0:12:16s
epoch 72 | loss: 0.6408 | val_auc: 0.60002 | 0:12:26s
epoch 73 | loss: 0.64118 | val_auc: 0.57349 | 0:12:36s
epoch 74 | loss: 0.63999 | val_auc: 0.59896 | 0:12:46s
epoch 75 | loss: 0.64145 | val_auc: 0.58056 | 0:12:55s
epoch 76 | loss: 0.63728 | val_auc: 0.60125 | 0:13:04s
epoch 77 | loss: 0.63837 | val_auc: 0.60125 | 0:13:14s
epoch 78 | loss: 0.63685 | val_auc: 0.5927 | 0:13:24s
epoch 79 | loss: 0.63236 | val_auc: 0.61108 | 0:13:35s
epoch 80 | loss: 0.62904 | val_auc: 0.58998 | 0:13:45s
epoch 81 | loss: 0.62744 | val_auc: 0.60016 | 0:13:55s
epoch 82 | loss: 0.62566 | val_auc: 0.60276 | 0:14:05s
epoch 83 | loss: 0.62634 | val_auc: 0.61056 | 0:14:15s
epoch 84 | loss: 0.62373 | val_auc: 0.58483 | 0:14:26s
epoch 85 | loss: 0.62367 | val_auc: 0.60493 | 0:14:37s
epoch 86 | loss: 0.62276 | val_auc: 0.58662 | 0:14:48s
epoch 87 | loss: 0.62387 | val_auc: 0.60769 | 0:14:58s
epoch 88 | loss: 0.62379 | val_auc: 0.60117 | 0:15:07s
epoch 89 | loss: 0.62122 | val_auc: 0.6004 | 0:15:17s
epoch 90 | loss: 0.62167 | val_auc: 0.6048 | 0:15:27s
epoch 91 | loss: 0.61774 | val_auc: 0.60171 | 0:15:36s
epoch 92 | loss: 0.62124 | val_auc: 0.59599 | 0:15:45s
epoch 93 | loss: 0.62162 | val_auc: 0.58103 | 0:15:55s
epoch 94 | loss: 0.61876 | val_auc: 0.61844 | 0:16:05s
epoch 95 | loss: 0.62024 | val_auc: 0.56773 | 0:16:15s
epoch 96 | loss: 0.61768 | val_auc: 0.59562 | 0:16:25s
epoch 97 | loss: 0.6174 | val_auc: 0.61055 | 0:16:35s
epoch 98 | loss: 0.61906 | val_auc: 0.60331 | 0:16:46s
epoch 99 | loss: 0.6201 | val_auc: 0.6006 | 0:16:57s
Stop training because you reached max_epochs = 100 with best_epoch = 94 and best_val_auc = 0.61844
Fold 4 AUC: 0.6184
precision recall f1-score support
0 0.4866 0.5244 0.5048 3730
1 0.6479 0.6126 0.6298 5328
accuracy 0.5763 9058
macro avg 0.5672 0.5685 0.5673 9058
weighted avg 0.5815 0.5763 0.5783 9058
MCC: 0.1357
==== Fold 5 ====
epoch 0 | loss: 0.7381 | val_auc: 0.53926 | 0:00:09s
epoch 1 | loss: 0.69417 | val_auc: 0.54001 | 0:00:18s
epoch 2 | loss: 0.69334 | val_auc: 0.51119 | 0:00:28s
epoch 3 | loss: 0.69369 | val_auc: 0.52295 | 0:00:37s
epoch 4 | loss: 0.69285 | val_auc: 0.52012 | 0:00:46s
epoch 5 | loss: 0.69289 | val_auc: 0.53658 | 0:00:55s
epoch 6 | loss: 0.69219 | val_auc: 0.5421 | 0:01:04s
epoch 7 | loss: 0.69106 | val_auc: 0.54826 | 0:01:14s
epoch 8 | loss: 0.69029 | val_auc: 0.54747 | 0:01:24s
epoch 9 | loss: 0.69032 | val_auc: 0.54592 | 0:01:39s
epoch 10 | loss: 0.68874 | val_auc: 0.55297 | 0:01:51s
epoch 11 | loss: 0.69012 | val_auc: 0.54611 | 0:02:02s
epoch 12 | loss: 0.68962 | val_auc: 0.55438 | 0:02:14s
epoch 13 | loss: 0.69062 | val_auc: 0.54867 | 0:02:25s
epoch 14 | loss: 0.68956 | val_auc: 0.5475 | 0:02:39s
epoch 15 | loss: 0.69103 | val_auc: 0.55061 | 0:02:51s
epoch 16 | loss: 0.6898 | val_auc: 0.55431 | 0:03:03s
epoch 17 | loss: 0.68735 | val_auc: 0.55323 | 0:03:15s
epoch 18 | loss: 0.68786 | val_auc: 0.55427 | 0:03:29s
epoch 19 | loss: 0.6885 | val_auc: 0.55725 | 0:03:40s
epoch 20 | loss: 0.68761 | val_auc: 0.56605 | 0:03:52s
epoch 21 | loss: 0.68725 | val_auc: 0.56148 | 0:04:04s
epoch 22 | loss: 0.68766 | val_auc: 0.55726 | 0:04:14s
epoch 23 | loss: 0.68739 | val_auc: 0.55932 | 0:04:24s
epoch 24 | loss: 0.68669 | val_auc: 0.56105 | 0:04:34s
epoch 25 | loss: 0.68812 | val_auc: 0.55484 | 0:04:43s
epoch 26 | loss: 0.68616 | val_auc: 0.56329 | 0:04:52s
epoch 27 | loss: 0.68647 | val_auc: 0.56096 | 0:05:01s
epoch 28 | loss: 0.6869 | val_auc: 0.56062 | 0:05:10s
epoch 29 | loss: 0.6858 | val_auc: 0.56036 | 0:05:18s
epoch 30 | loss: 0.68557 | val_auc: 0.56423 | 0:05:27s
epoch 31 | loss: 0.68521 | val_auc: 0.55339 | 0:05:35s
epoch 32 | loss: 0.68604 | val_auc: 0.56167 | 0:05:43s
epoch 33 | loss: 0.68492 | val_auc: 0.55643 | 0:05:51s
epoch 34 | loss: 0.6846 | val_auc: 0.56229 | 0:05:59s
epoch 35 | loss: 0.68464 | val_auc: 0.55972 | 0:06:07s
epoch 36 | loss: 0.68277 | val_auc: 0.55484 | 0:06:15s
epoch 37 | loss: 0.68264 | val_auc: 0.56243 | 0:06:23s
epoch 38 | loss: 0.68409 | val_auc: 0.56235 | 0:06:31s
epoch 39 | loss: 0.68236 | val_auc: 0.55643 | 0:06:38s
epoch 40 | loss: 0.68221 | val_auc: 0.56639 | 0:06:45s
epoch 41 | loss: 0.68162 | val_auc: 0.56583 | 0:06:53s
epoch 42 | loss: 0.67926 | val_auc: 0.56331 | 0:07:00s
epoch 43 | loss: 0.6806 | val_auc: 0.56547 | 0:07:07s
epoch 44 | loss: 0.68012 | val_auc: 0.56219 | 0:07:15s
epoch 45 | loss: 0.67786 | val_auc: 0.56578 | 0:07:22s
epoch 46 | loss: 0.67827 | val_auc: 0.56375 | 0:07:31s
epoch 47 | loss: 0.67381 | val_auc: 0.57245 | 0:07:40s
epoch 48 | loss: 0.67332 | val_auc: 0.57125 | 0:07:49s
epoch 49 | loss: 0.6708 | val_auc: 0.57653 | 0:07:57s
epoch 50 | loss: 0.6678 | val_auc: 0.57729 | 0:08:06s
epoch 51 | loss: 0.66534 | val_auc: 0.58197 | 0:08:14s
epoch 52 | loss: 0.66391 | val_auc: 0.58182 | 0:08:24s
epoch 53 | loss: 0.66229 | val_auc: 0.58137 | 0:08:33s
epoch 54 | loss: 0.65908 | val_auc: 0.59146 | 0:08:42s
epoch 55 | loss: 0.65998 | val_auc: 0.59352 | 0:08:51s
epoch 56 | loss: 0.65774 | val_auc: 0.5932 | 0:09:00s
epoch 57 | loss: 0.65654 | val_auc: 0.59505 | 0:09:09s
epoch 58 | loss: 0.65287 | val_auc: 0.59746 | 0:09:18s
epoch 59 | loss: 0.64709 | val_auc: 0.60005 | 0:09:26s
epoch 60 | loss: 0.64252 | val_auc: 0.59609 | 0:09:35s
epoch 61 | loss: 0.63922 | val_auc: 0.60742 | 0:09:45s
epoch 62 | loss: 0.63946 | val_auc: 0.6078 | 0:09:54s
epoch 63 | loss: 0.63713 | val_auc: 0.60222 | 0:10:03s
epoch 64 | loss: 0.63655 | val_auc: 0.60803 | 0:10:11s
epoch 65 | loss: 0.63522 | val_auc: 0.6067 | 0:10:19s
epoch 66 | loss: 0.63349 | val_auc: 0.60649 | 0:10:27s
epoch 67 | loss: 0.63384 | val_auc: 0.60731 | 0:10:36s
epoch 68 | loss: 0.63134 | val_auc: 0.59559 | 0:10:45s
epoch 69 | loss: 0.63047 | val_auc: 0.58947 | 0:10:53s
epoch 70 | loss: 0.63179 | val_auc: 0.60461 | 0:11:00s
epoch 71 | loss: 0.62901 | val_auc: 0.60928 | 0:11:08s
epoch 72 | loss: 0.62975 | val_auc: 0.60007 | 0:11:17s
epoch 73 | loss: 0.62767 | val_auc: 0.60926 | 0:11:26s
epoch 74 | loss: 0.62735 | val_auc: 0.60034 | 0:11:34s
epoch 75 | loss: 0.62829 | val_auc: 0.60422 | 0:11:43s
epoch 76 | loss: 0.62898 | val_auc: 0.60282 | 0:11:51s
epoch 77 | loss: 0.62904 | val_auc: 0.59976 | 0:12:00s
epoch 78 | loss: 0.6247 | val_auc: 0.59464 | 0:12:09s
epoch 79 | loss: 0.62823 | val_auc: 0.6074 | 0:12:18s
epoch 80 | loss: 0.62719 | val_auc: 0.60294 | 0:12:26s
epoch 81 | loss: 0.62525 | val_auc: 0.61254 | 0:12:35s
epoch 82 | loss: 0.62501 | val_auc: 0.60337 | 0:12:44s
epoch 83 | loss: 0.62337 | val_auc: 0.61158 | 0:12:52s
epoch 84 | loss: 0.62474 | val_auc: 0.60949 | 0:13:00s
epoch 85 | loss: 0.62542 | val_auc: 0.60143 | 0:13:09s
epoch 86 | loss: 0.62309 | val_auc: 0.61275 | 0:13:17s
epoch 87 | loss: 0.62445 | val_auc: 0.61873 | 0:13:26s
epoch 88 | loss: 0.62233 | val_auc: 0.61826 | 0:13:34s
epoch 89 | loss: 0.62163 | val_auc: 0.62097 | 0:13:43s
epoch 90 | loss: 0.6205 | val_auc: 0.62121 | 0:13:51s
epoch 91 | loss: 0.6206 | val_auc: 0.61683 | 0:14:00s
epoch 92 | loss: 0.62121 | val_auc: 0.61265 | 0:14:09s
epoch 93 | loss: 0.62048 | val_auc: 0.60305 | 0:14:19s
epoch 94 | loss: 0.6186 | val_auc: 0.60847 | 0:14:28s
epoch 95 | loss: 0.62331 | val_auc: 0.60684 | 0:14:38s
epoch 96 | loss: 0.61866 | val_auc: 0.60686 | 0:14:48s
epoch 97 | loss: 0.62032 | val_auc: 0.60096 | 0:14:59s
epoch 98 | loss: 0.62132 | val_auc: 0.61791 | 0:15:09s
epoch 99 | loss: 0.62071 | val_auc: 0.60238 | 0:15:18s
Stop training because you reached max_epochs = 100 with best_epoch = 90 and best_val_auc = 0.62121
Fold 5 AUC: 0.6212
precision recall f1-score support
0 0.9413 0.2193 0.3557 3730
1 0.6444 0.9904 0.7808 5328
accuracy 0.6729 9058
macro avg 0.7929 0.6049 0.5683 9058
weighted avg 0.7667 0.6729 0.6058 9058
MCC: 0.3505
==== Cross-Validation Complete ==== Mean AUC: 0.5790 | Std AUC: 0.0334 Successfully saved model at Needs_Maintainence_tabnet_model.zip
'Needs_Maintainence_tabnet_model.zip'
Extra evaluation on held-out data
# --- Hold-out evaluation of the trained TabNet maintenance model ---
# Aliases are kept at module level because later cells reference TAB_x / TAB_y.
TAB_x = TAB_X_val
TAB_y = TAB_y_val

# Class-1 probability scores and hard label predictions from the fitted model.
y_pred_proba = clf_M.predict_proba(TAB_x)[:, 1]
y_pred = clf_M.predict(TAB_x)

# Threshold-independent ranking quality.
auc_score = roc_auc_score(TAB_y, y_pred_proba)
print(f"\nFinal Validation AUC: {auc_score:.4f}")

# Per-class precision / recall / F1 at the model's default decision threshold.
print("Classification Report:")
print(classification_report(TAB_y, y_pred, digits=4))

# MCC: single-number summary that stays informative under class imbalance.
mcc = matthews_corrcoef(TAB_y, y_pred)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

# Confusion matrix rendered as a heat map.
cm = confusion_matrix(TAB_y, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix")
plt.show()
# ROC curve on the hold-out set; `auc` integrates the curve with the
# trapezoidal rule, so it should agree with roc_auc_score above.
from sklearn.metrics import roc_curve, auc

fpr, tpr, thresholds = roc_curve(TAB_y, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
# Chance-level diagonal for reference.
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid()
plt.tight_layout()
plt.show()
# Distribution of predicted probabilities per true class: well-separated
# histograms indicate a discriminative model; the heavy overlap here is
# consistent with the modest AUC (~0.62) reported above.
#
# Fix: this cell evaluates the *maintenance* model (clf_M), so the legend
# labels describe maintenance need — the previous "Failure" labels were
# copied from the failure-prediction model and misdescribed the plot.
plt.figure(figsize=(6, 4))
plt.hist(y_pred_proba[TAB_y == 0], bins=30, alpha=0.6, label='Class 0 (No Maintenance Needed)', color='skyblue')
plt.hist(y_pred_proba[TAB_y == 1], bins=30, alpha=0.6, label='Class 1 (Needs Maintenance)', color='salmon')
plt.xlabel('Predicted Probability')
plt.ylabel('Count')
plt.title('Histogram of Predicted Probabilities by Class')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()
import seaborn as sns

# Global feature importance from TabNet's attention-mask explanation.
# explain() yields a per-sample importance matrix plus the per-step masks;
# only the aggregated matrix is used here (masks are discarded).
explain_matrix, masks = clf_M.explain(TAB_x)

# NOTE(review): column names are taken from X_val while the importances were
# computed on TAB_x — this assumes both share the same feature order; verify.
feature_importances_df = (
    pd.DataFrame(
        {
            'Feature': X_val.columns,
            # Sum over samples -> one total attention weight per feature.
            'Importance': explain_matrix.sum(axis=0),
        }
    )
    .sort_values(by='Importance', ascending=False)
)

# Horizontal bar chart, most important feature on top.
plt.figure(figsize=(8, 6))
sns.barplot(data=feature_importances_df, x='Importance', y='Feature', palette='viridis')
plt.title("TabNet Feature Importances (Global Explanation)")
plt.tight_layout()
plt.show()
Final Validation AUC: 0.6212
Classification Report:
precision recall f1-score support
0 0.9413 0.2193 0.3557 3730
1 0.6444 0.9904 0.7808 5328
accuracy 0.6729 9058
macro avg 0.7929 0.6049 0.5683 9058
weighted avg 0.7667 0.6729 0.6058 9058
Matthews Correlation Coefficient (MCC): 0.3505